From 81c642f0716083f89fb914bf67984d90ac468c91 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 16 Dec 2023 19:31:17 +0100 Subject: [PATCH 001/116] initial-commit --- docs/source/en/_toctree.yml | 2 + docs/source/en/model_doc/mamba.md | 118 +++ src/transformers/__init__.py | 16 + src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 3 + src/transformers/models/auto/modeling_auto.py | 4 + .../models/auto/tokenization_auto.py | 1 + .../models/flava/modeling_flava.py | 2 +- src/transformers/models/mamba/__init__.py | 60 ++ .../models/mamba/configuration_mamba.py | 122 +++ .../mamba/convert_mamba_checkpoint_to_hf.py | 201 ++++ .../models/mamba/modeling_mamba.py | 859 ++++++++++++++++++ tests/models/mamba/__init__.py | 0 tests/models/mamba/test_modeling_mamba.py | 452 +++++++++ 14 files changed, 1840 insertions(+), 1 deletion(-) create mode 100644 docs/source/en/model_doc/mamba.md create mode 100644 src/transformers/models/mamba/__init__.py create mode 100644 src/transformers/models/mamba/configuration_mamba.py create mode 100644 src/transformers/models/mamba/convert_mamba_checkpoint_to_hf.py create mode 100644 src/transformers/models/mamba/modeling_mamba.py create mode 100644 tests/models/mamba/__init__.py create mode 100644 tests/models/mamba/test_modeling_mamba.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 57ab9769b60f80..b4eebe62364d8d 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -386,6 +386,8 @@ title: M2M100 - local: model_doc/madlad-400 title: MADLAD-400 + - local: model_doc/mamba + title: Mamba - local: model_doc/marian title: MarianMT - local: model_doc/markuplm diff --git a/docs/source/en/model_doc/mamba.md b/docs/source/en/model_doc/mamba.md new file mode 100644 index 00000000000000..0a56eca918cc6e --- /dev/null +++ b/docs/source/en/model_doc/mamba.md @@ -0,0 +1,118 @@ + + +# Mamba + +# Mamba + +# Mamba + +## Overview + +The Mamba model was proposed in []() by . + + +The abstract from the paper is the following: + +** + +Tips: + + + +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). +The original code can be found [here](). + + +## MambaConfig + +[[autodoc]] MambaConfig + +## MambaModel + +[[autodoc]] MambaModel + - forward + +## MambaLMHeadModel + +[[autodoc]] MambaForCausalLM + - forward + +## Mamba attention and the recurrent formulas + +In a traditional auto-regressive Transformer, attention is written as + +$$O = \hbox{softmax}(QK^{T} / \sqrt{d}) V$$ + +with \\(Q\\), \\(K\\) and \\(V\\) are matrices of shape `seq_len x hidden_size` named query, key and value (they are actually bigger matrices with a batch dimension and an attention head dimension but we're only interested in the last two, which is where the matrix product is taken, so for the sake of simplicity we only consider those two). The product \\(QK^{T}\\) then has shape `seq_len x seq_len` and we can take the maxtrix product with \\(V\\) to get the output \\(O\\) of the same shape as the others. + +Replacing the softmax by its value gives: + +$$O_{i} = \frac{\sum_{j=1}^{i} e^{Q_{i} K_{j}^{T} / \sqrt{d}} V_{j}}{\sum_{j=1}^{i} e^{Q_{i} K_{j}^{T} / \sqrt{d}}}$$ + +Note that the entries in \\(QK^{T}\\) corresponding to \\(j > i\\) are masked (the sum stops at j) because the attention is not allowed to look at future tokens (only past ones). + +In comparison, the MAMBA attention is given by + +$$O_{i} = \sigma(R_{i}) \frac{\sum_{j=1}^{i} e^{W_{i-j} + K_{j}} V_{j}}{\sum_{j=1}^{i} e^{W_{i-j} + K_{j}}}$$ + +where \\(R\\) is a new matrix called receptance by the author, \\(K\\) and \\(V\\) are still the key and value (\\(\sigma\\) here is the sigmoid function). \\(W\\) is a new vector that represents the position of the token and is given by + +$$W_{0} = u \hbox{ and } W_{k} = (k-1)w \hbox{ for } k \geq 1$$ + +with \\(u\\) and \\(w\\) learnable parameters called in the code `time_first` and `time_decay` respectively. The numerator and denominator can both be expressed recursively. Naming them \\(N_{i}\\) and \\(D_{i}\\) we have: + +$$N_{i} = e^{u + K_{i}} V_{i} + \hat{N}_{i} \hbox{ where } \hat{N}_{i} = e^{K_{i-1}} V_{i-1} + e^{w + K_{i-2}} V_{i-2} \cdots + e^{(i-2)w + K_{1}} V_{1}$$ + +so \\(\hat{N}_{i}\\) (called `numerator_state` in the code) satistfies + +$$\hat{N}_{0} = 0 \hbox{ and } \hat{N}_{j+1} = e^{K_{j}} V_{j} + e^{w} \hat{N}_{j}$$ + +and + +$$D_{i} = e^{u + K_{i}} + \hat{D}_{i} \hbox{ where } \hat{D}_{i} = e^{K_{i-1}} + e^{w + K_{i-2}} \cdots + e^{(i-2)w + K_{1}}$$ + +so \\(\hat{D}_{i}\\) (called `denominator_state` in the code) satistfies + +$$\hat{D}_{0} = 0 \hbox{ and } \hat{D}_{j+1} = e^{K_{j}} + e^{w} \hat{D}_{j}$$ + +The actual recurrent formula used are a tiny bit more complex, as for numerical stability we don't want to compute exponentials of big numbers. Usually the softmax is not computed as is, but the exponential of the maximum term is divided of the numerator and denominator: + +$$\frac{e^{x_{i}}}{\sum_{j=1}^{n} e^{x_{j}}} = \frac{e^{x_{i} - M}}{\sum_{j=1}^{n} e^{x_{j} - M}}$$ + +with \\(M\\) the maximum of all \\(x_{j}\\). So here on top of saving the numerator state (\\(\hat{N}\\)) and the denominator state (\\(\hat{D}\\)) we also keep track of the maximum of all terms encountered in the exponentials. So we actually use + +$$\tilde{N}_{i} = e^{-M_{i}} \hat{N}_{i} \hbox{ and } \tilde{D}_{i} = e^{-M_{i}} \hat{D}_{i}$$ + +defined by the following recurrent formulas: + +$$\tilde{N}_{0} = 0 \hbox{ and } \tilde{N}_{j+1} = e^{K_{j} - q} V_{j} + e^{w + M_{j} - q} \tilde{N}_{j} \hbox{ where } q = \max(K_{j}, w + M_{j})$$ + +and + +$$\tilde{D}_{0} = 0 \hbox{ and } \tilde{D}_{j+1} = e^{K_{j} - q} + e^{w + M_{j} - q} \tilde{D}_{j} \hbox{ where } q = \max(K_{j}, w + M_{j})$$ + +and \\(M_{j+1} = q\\). With those, we can then compute + +$$N_{i} = e^{u + K_{i} - q} V_{i} + e^{M_{i}} \tilde{N}_{i} \hbox{ where } q = \max(u + K_{i}, M_{i})$$ + +and + +$$D_{i} = e^{u + K_{i} - q} + e^{M_{i}} \tilde{D}_{i} \hbox{ where } q = \max(u + K_{i}, M_{i})$$ + +which finally gives us + +$$O_{i} = \sigma(R_{i}) \frac{N_{i}}{D_{i}}$$ \ No newline at end of file diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 98139511d239c5..c5f901e6ee1ba2 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -735,6 +735,7 @@ "RoFormerTokenizer", ], "models.rwkv": ["RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP", "RwkvConfig"], + "models.mamba": ["MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MambaConfig"], "models.sam": [ "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP", "SamConfig", @@ -3078,6 +3079,14 @@ "RwkvPreTrainedModel", ] ) + _import_structure["models.mamba"].extend( + [ + "MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST", + "MambaForCausalLM", + "MambaModel", + "MambaPreTrainedModel", + ] + ) _import_structure["models.sam"].extend( [ "SAM_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -5396,6 +5405,7 @@ RoFormerTokenizer, ) from .models.rwkv import RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP, RwkvConfig + from .models.mamba import MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP, MambaConfig from .models.sam import ( SAM_PRETRAINED_CONFIG_ARCHIVE_MAP, SamConfig, @@ -7428,6 +7438,12 @@ RwkvModel, RwkvPreTrainedModel, ) + from .models.mamba import ( + MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST, + MambaForCausalLM, + MambaModel, + MambaPreTrainedModel, + ) from .models.sam import ( SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 319c8499319a3f..f15ac6f2b2b39d 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -186,6 +186,7 @@ roc_bert, roformer, rwkv, + mamba, sam, seamless_m4t, seamless_m4t_v2, diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index b91226ac877897..9945b82b3f56fc 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -193,6 +193,7 @@ ("roc_bert", "RoCBertConfig"), ("roformer", "RoFormerConfig"), ("rwkv", "RwkvConfig"), + ("mamba", "MambaConfig"), ("sam", "SamConfig"), ("seamless_m4t", "SeamlessM4TConfig"), ("seamless_m4t_v2", "SeamlessM4Tv2Config"), @@ -411,6 +412,7 @@ ("roc_bert", "ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("rwkv", "RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("mamba", "MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("sam", "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("seamless_m4t", "SEAMLESS_M4T_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("seamless_m4t_v2", "SEAMLESS_M4T_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -655,6 +657,7 @@ ("roc_bert", "RoCBert"), ("roformer", "RoFormer"), ("rwkv", "RWKV"), + ("mamba", "Mamba"), ("sam", "SAM"), ("seamless_m4t", "SeamlessM4T"), ("seamless_m4t_v2", "SeamlessM4Tv2"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index e562bd28bdb3f3..32a48ef1c4a84a 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -186,6 +186,7 @@ ("roc_bert", "RoCBertModel"), ("roformer", "RoFormerModel"), ("rwkv", "RwkvModel"), + ("mamba", "MambaModel"), ("sam", "SamModel"), ("seamless_m4t", "SeamlessM4TModel"), ("seamless_m4t_v2", "SeamlessM4Tv2Model"), @@ -291,6 +292,7 @@ ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), ("roc_bert", "RoCBertForPreTraining"), ("rwkv", "RwkvForCausalLM"), + ("mamba", "MambaForCausalLM"), ("splinter", "SplinterForPreTraining"), ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), @@ -380,6 +382,7 @@ ("roc_bert", "RoCBertForMaskedLM"), ("roformer", "RoFormerForMaskedLM"), ("rwkv", "RwkvForCausalLM"), + ("mamba", "MambaForCausalLM"), ("speech_to_text", "Speech2TextForConditionalGeneration"), ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), @@ -453,6 +456,7 @@ ("roc_bert", "RoCBertForCausalLM"), ("roformer", "RoFormerForCausalLM"), ("rwkv", "RwkvForCausalLM"), + ("mamba", "MambaForCausalLM"), ("speech_to_text_2", "Speech2Text2ForCausalLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("trocr", "TrOCRForCausalLM"), diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index 9e4066de99a5f9..8874390a5d118f 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -347,6 +347,7 @@ ("roc_bert", ("RoCBertTokenizer", None)), ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), ("rwkv", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), + ("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ( "seamless_m4t", ( diff --git a/src/transformers/models/flava/modeling_flava.py b/src/transformers/models/flava/modeling_flava.py index 64ede9c89ed88b..25af3dfb4c5dd3 100644 --- a/src/transformers/models/flava/modeling_flava.py +++ b/src/transformers/models/flava/modeling_flava.py @@ -1949,7 +1949,7 @@ def forward( if mim_labels is not None: mim_labels = mim_labels[pos_mask] - + # MMM Image Loss if multimodal_masked_embeddings is not None and self.mmm_image_weight > 0: sequence_for_image = multimodal_masked_embeddings diff --git a/src/transformers/models/mamba/__init__.py b/src/transformers/models/mamba/__init__.py new file mode 100644 index 00000000000000..7a629ac396672a --- /dev/null +++ b/src/transformers/models/mamba/__init__.py @@ -0,0 +1,60 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, +) + + +_import_structure = { + "configuration_mamba": ["MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MambaConfig", "MambaOnnxConfig"], +} + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_mamba"] = [ + "MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST", + "MambaForCausalLM", + "MambaModel", + "MambaPreTrainedModel", + ] + + +if TYPE_CHECKING: + from .configuration_mamba import MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP, MambaConfig, MambaOnnxConfig + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_mamba import ( + MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST, + MambaForCausalLM, + MambaModel, + MambaPreTrainedModel, + ) +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py new file mode 100644 index 00000000000000..2ad117f4bd78e5 --- /dev/null +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -0,0 +1,122 @@ +# coding=utf-8 +# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MAMBA configuration""" + +from ...configuration_utils import PretrainedConfig +from ...utils import logging + + +logger = logging.get_logger(__name__) + +MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "state-spaces/mamba-2.8b": "https://huggingface.co/state-spaces/mamba-2.8b/resolve/main/config.json", +} + + + +class MambaConfig(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`MambaModel`]. It is used to instantiate a MAMBA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the RWVK-4 + [state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50277): + Vocabulary size of the MAMBA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`MambaModel`]. + context_length (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model can be be used with in a single forward (using it in RNN mode + lets use any sequence length). + hidden_size (`int`, *optional*, defaults to 4096): + Dimensionality of the embeddings and hidden states. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the model. + attention_hidden_size (`int`, *optional*): + Dimensionality of the attention hidden states. Will default to `hidden_size` if unset. + intermediate_size (`int`, *optional*): + Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + bos_token_id (`int`, *optional*, defaults to 0): + The id of the beginning of sentence token in the vocabulary. Defaults to 0 as MAMBA uses the same tokenizer + as GPTNeoX. + eos_token_id (`int`, *optional*, defaults to 0): + The id of the end of sentence token in the vocabulary. Defaults to 0 as MAMBA uses the same tokenizer as + GPTNeoX. + rescale_every (`int`, *optional*, defaults to 6): + At inference, the hidden states (and weights of the correponding output layers) are divided by 2 every + `rescale_every` layer. If set to 0 or a negative number, no rescale is done. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to tie the word embeddings with the input token embeddings. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last state. + + + Example: + + ```python + >>> from transformers import MambaConfig, MambaModel + + >>> # Initializing a Mamba configuration + >>> configuration = MambaConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = MambaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "mamba" + attribute_map = {"max_position_embeddings": "context_length"} + + def __init__( + self, + vocab_size=50277, + context_length=1024, + hidden_size=4096, + num_hidden_layers=32, + attention_hidden_size=None, + intermediate_size=None, + layer_norm_epsilon=1e-5, + bos_token_id=0, + eos_token_id=0, + rescale_every=6, + tie_word_embeddings=False, + use_cache=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.context_length = context_length + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.attention_hidden_size = attention_hidden_size if attention_hidden_size is not None else hidden_size + self.intermediate_size = intermediate_size if intermediate_size is not None else 4 * hidden_size + self.layer_norm_epsilon = layer_norm_epsilon + self.rescale_every = rescale_every + self.use_cache = use_cache + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__( + tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs + ) diff --git a/src/transformers/models/mamba/convert_mamba_checkpoint_to_hf.py b/src/transformers/models/mamba/convert_mamba_checkpoint_to_hf.py new file mode 100644 index 00000000000000..914dab0229f3b6 --- /dev/null +++ b/src/transformers/models/mamba/convert_mamba_checkpoint_to_hf.py @@ -0,0 +1,201 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert a MAMBA checkpoint from BlinkDL to the Hugging Face format.""" + + +import argparse +import gc +import json +import os +import re + +import torch +from huggingface_hub import hf_hub_download + +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerFast, MambaConfig +from transformers.modeling_utils import WEIGHTS_INDEX_NAME, shard_checkpoint + + +NUM_HIDDEN_LAYERS_MAPPING = { + "169M": 12, + "430M": 24, + "1B5": 24, + "3B": 32, + "7B": 32, + "14B": 40, +} + +HIDEN_SIZE_MAPPING = { + "169M": 768, + "430M": 1024, + "1B5": 2048, + "3B": 2560, + "7B": 4096, + "14B": 5120, +} + + +def convert_state_dict(state_dict): + state_dict_keys = list(state_dict.keys()) + for name in state_dict_keys: + weight = state_dict.pop(name) + # emb -> embedding + if name.startswith("emb."): + name = name.replace("emb.", "embeddings.") + # ln_0 -> pre_ln (only present at block 0) + if name.startswith("blocks.0.ln0"): + name = name.replace("blocks.0.ln0", "blocks.0.pre_ln") + # att -> attention + name = re.sub(r"blocks\.(\d+)\.att", r"blocks.\1.attention", name) + # ffn -> feed_forward + name = re.sub(r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward", name) + # time_mix_k -> time_mix_key and reshape + if name.endswith(".time_mix_k"): + name = name.replace(".time_mix_k", ".time_mix_key") + # time_mix_v -> time_mix_value and reshape + if name.endswith(".time_mix_v"): + name = name.replace(".time_mix_v", ".time_mix_value") + # time_mix_r -> time_mix_key and reshape + if name.endswith(".time_mix_r"): + name = name.replace(".time_mix_r", ".time_mix_receptance") + + if name != "head.weight": + name = "mamba." + name + + state_dict[name] = weight + return state_dict + + +def convert_rmkv_checkpoint_to_hf_format( + repo_id, checkpoint_file, output_dir, size=None, tokenizer_file=None, push_to_hub=False, model_name=None +): + # 1. If possible, build the tokenizer. + if tokenizer_file is None: + print("No `--tokenizer_file` provided, we will use the default tokenizer.") + vocab_size = 50277 + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + else: + tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file) + vocab_size = len(tokenizer) + tokenizer.save_pretrained(output_dir) + + # 2. Build the config + possible_sizes = list(NUM_HIDDEN_LAYERS_MAPPING.keys()) + if size is None: + # Try to infer size from the checkpoint name + for candidate in possible_sizes: + if candidate in checkpoint_file: + size = candidate + break + if size is None: + raise ValueError("Could not infer the size, please provide it with the `--size` argument.") + if size not in possible_sizes: + raise ValueError(f"`size` should be one of {possible_sizes}, got {size}.") + + config = MambaConfig( + vocab_size=vocab_size, + num_hidden_layers=NUM_HIDDEN_LAYERS_MAPPING[size], + hidden_size=HIDEN_SIZE_MAPPING[size], + ) + config.save_pretrained(output_dir) + + # 3. Download model file then convert state_dict + model_file = hf_hub_download(repo_id, checkpoint_file) + state_dict = torch.load(model_file, map_location="cpu") + state_dict = convert_state_dict(state_dict) + + # 4. Split in shards and save + shards, index = shard_checkpoint(state_dict) + for shard_file, shard in shards.items(): + torch.save(shard, os.path.join(output_dir, shard_file)) + + if index is not None: + save_index_file = os.path.join(output_dir, WEIGHTS_INDEX_NAME) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + + # 5. Clean up shards (for some reason the file PyTorch saves take the same space as the whole state_dict + print( + "Cleaning up shards. This may error with an OOM error, it this is the case don't worry you still have converted the model." + ) + shard_files = list(shards.keys()) + + del state_dict + del shards + gc.collect() + + for shard_file in shard_files: + state_dict = torch.load(os.path.join(output_dir, shard_file)) + torch.save({k: v.cpu().clone() for k, v in state_dict.items()}, os.path.join(output_dir, shard_file)) + + del state_dict + gc.collect() + + if push_to_hub: + if model_name is None: + raise ValueError("Please provide a `model_name` to push the model to the Hub.") + model = AutoModelForCausalLM.from_pretrained(output_dir) + model.push_to_hub(model_name, max_shard_size="2GB") + tokenizer.push_to_hub(model_name) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # Required parameters + parser.add_argument( + "--repo_id", default=None, type=str, required=True, help="Repo ID from which to pull the checkpoint." + ) + parser.add_argument( + "--checkpoint_file", default=None, type=str, required=True, help="Name of the checkpoint file in the repo." + ) + parser.add_argument( + "--output_dir", default=None, type=str, required=True, help="Where to save the converted model." + ) + parser.add_argument( + "--tokenizer_file", + default=None, + type=str, + help="Path to the tokenizer file to use (if not provided, only the model is converted).", + ) + parser.add_argument( + "--size", + default=None, + type=str, + help="Size of the model. Will be inferred from the `checkpoint_file` if not passed.", + ) + parser.add_argument( + "--push_to_hub", + action="store_true", + help="Push to the Hub the converted model.", + ) + parser.add_argument( + "--model_name", + default=None, + type=str, + help="Name of the pushed model on the Hub, including the username / organization.", + ) + + args = parser.parse_args() + convert_rmkv_checkpoint_to_hf_format( + args.repo_id, + args.checkpoint_file, + args.output_dir, + size=args.size, + tokenizer_file=args.tokenizer_file, + push_to_hub=args.push_to_hub, + model_name=args.model_name, + ) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py new file mode 100644 index 00000000000000..9ba8c840b07d29 --- /dev/null +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -0,0 +1,859 @@ +# coding=utf-8 +# Copyright 2023 Bo Peng and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MAMBA model.""" + +import math +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +from ...modeling_utils import PreTrainedModel +from ...utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_bitsandbytes_available, + is_ninja_available, + is_torch_cuda_available, + logging, +) +from .configuration_mamba import MambaConfig + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "state-spaces/mamba-2.8b" +_CONFIG_FOR_DOC = "MambaConfig" + +MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "state-spaces/mamba-2.8b", + # See all Mamba models at https://huggingface.co/models?filter=mamba +] + + + +mamba_cuda_kernel = None + + +# Copied from transformers.models.rwkv.modeling_rwkv.load_wkv_cuda_kernel with RWKV->MAMBA,rwkv->mamba +def load_wkv_cuda_kernel(context_length): + from torch.utils.cpp_extension import load as load_kernel + + global mamba_cuda_kernel + + kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "mamba" + cuda_kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu", "wkv_cuda_bf16.cu"]] + + # Only load the kernel if it's not been loaded yet or if we changed the context length + if mamba_cuda_kernel is not None and mamba_cuda_kernel.max_seq_length == context_length: + return + + logger.info(f"Loading CUDA kernel for MAMBA at context length of {context_length}.") + + flags = [ + "-res-usage", + "--maxrregcount 60", + "--use_fast_math", + "-O3", + "-Xptxas -O3", + "--extra-device-vectorization", + f"-DTmax={context_length}", + ] + mamba_cuda_kernel = load_kernel( + name=f"wkv_{context_length}", + sources=cuda_kernel_files, + verbose=(logging.get_verbosity() == logging.DEBUG), + extra_cuda_cflags=flags, + ) + mamba_cuda_kernel.max_seq_length = context_length + + +# Copied from transformers.models.rwkv.modeling_rwkv.RwkvLinearAttention with Rwkv->Mamba,rwkv->mamba +class MambaLinearAttention(torch.autograd.Function): + @staticmethod + def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False): + batch_size, seq_len, hidden_size = key.size() + if seq_len > mamba_cuda_kernel.max_seq_length: + raise ValueError( + f"Cannot process a batch with {seq_len} tokens at the same time, use a maximum of " + f"{mamba_cuda_kernel.max_seq_length} with this model." + ) + if batch_size * hidden_size % min(hidden_size, 32) != 0: + raise ValueError( + f"The product of batch size ({batch_size}) and hidden size ({hidden_size}) needs to be a round " + f"multiple of {min(hidden_size, 32)}." + ) + + ctx.input_dtype = key.dtype + + if ( + time_decay.device.type != "cuda" + or time_first.device.type != "cuda" + or key.device.type != "cuda" + or value.device.type != "cuda" + ): + raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.") + + time_decay = -torch.exp(time_decay.float().contiguous()) + if key.dtype == torch.float16: + time_first = time_first.float() + key = key.float() + value = value.float() + time_first = time_first.contiguous() + key = key.contiguous() + value = value.contiguous() + # The CUDA kernel will fill this tensor. + output = torch.empty_like(key, memory_format=torch.contiguous_format) + if return_state or state is not None: + if state is None: + state = torch.zeros( + batch_size, + hidden_size, + 3, + dtype=torch.float32, + device=key.device, + memory_format=torch.contiguous_format, + ) + state[:, :, 2] -= 1e38 + else: + state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous() + if key.dtype == torch.bfloat16: + forward_func = mamba_cuda_kernel.forward_with_state_bf16 + else: + forward_func = mamba_cuda_kernel.forward_with_state + forward_func(time_decay, time_first, key, value, output, state) + else: + forward_func = mamba_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else mamba_cuda_kernel.forward + forward_func(time_decay, time_first, key, value, output) + + ctx.save_for_backward(time_decay, time_first, key, value, output) + + if state is not None: + state = [s.squeeze(2) for s in torch.chunk(state, 3, dim=2)] + + return output.to(ctx.input_dtype), state + + @staticmethod + # g stands for grad + def backward(ctx, g_output, g_state=None): + input_dtype = ctx.input_dtype + + time_decay, time_first, key, value, output = ctx.saved_tensors + # The CUDA kernel will fill those tensors. + g_time_decay = torch.empty_like( + time_decay, + memory_format=torch.contiguous_format, + dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32, + ) + g_time_first = torch.empty_like(time_first, memory_format=torch.contiguous_format) + g_key = torch.empty_like(key, memory_format=torch.contiguous_format) + g_value = torch.empty_like(value, memory_format=torch.contiguous_format) + + if input_dtype == torch.float16: + g_output = g_output.float() + backward_func = mamba_cuda_kernel.backward_bf16 if input_dtype == torch.bfloat16 else mamba_cuda_kernel.backward + backward_func( + time_decay, + time_first, + key, + value, + output, + g_output.contiguous(), + g_time_decay, + g_time_first, + g_key, + g_value, + ) + + return ( + g_time_decay.to(input_dtype), + g_time_first.to(input_dtype), + g_key.to(input_dtype), + g_value.to(input_dtype), + None, + None, + ) + + +# Copied from transformers.models.rwkv.modeling_rwkv.rwkv_linear_attention_cpu with rwkv->mamba +def mamba_linear_attention_cpu(time_decay, time_first, key, value, state=None, return_state=False): + # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed + # within a torch.no_grad. + _, seq_length, _ = key.size() + output = torch.zeros_like(key) + + if state is None: + num_state = torch.zeros_like(key[:, 0], dtype=torch.float32) + den_state = torch.zeros_like(key[:, 0], dtype=torch.float32) + max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38 + else: + num_state, den_state, max_state = state + # For numerical stability + # real_numerator_state = num_state * torch.exp(max_state) + # real_denominator_state = den_state * torch.exp(max_state) + + time_decay = -torch.exp(time_decay) + + for current_index in range(seq_length): + current_key = key[:, current_index].float() + current_value = value[:, current_index] + + # wkv computation at time t + max_for_output = torch.maximum(max_state, current_key + time_first) + e1 = torch.exp(max_state - max_for_output) + e2 = torch.exp(current_key + time_first - max_for_output) + numerator = e1 * num_state + e2 * current_value + denominator = e1 * den_state + e2 + output[:, current_index] = (numerator / denominator).to(output.dtype) + + # Update state for next iteration + max_for_state = torch.maximum(max_state + time_decay, current_key) + e1 = torch.exp(max_state + time_decay - max_for_state) + e2 = torch.exp(current_key - max_for_state) + num_state = e1 * num_state + e2 * current_value + den_state = e1 * den_state + e2 + max_state = max_for_state + + if return_state or state is not None: + state = [num_state, den_state, max_state] + + return output, state + + +# Copied from transformers.models.rwkv.modeling_rwkv.rwkv_linear_attention with Rwkv->Mamba,rwkv->mamba +def mamba_linear_attention(time_decay, time_first, key, value, state=None, return_state=False): + no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value]) + # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version + # in this case). + one_token = key.size(1) == 1 + if mamba_cuda_kernel is None or no_cuda or one_token: + return mamba_linear_attention_cpu(time_decay, time_first, key, value, state=state, return_state=return_state) + else: + return MambaLinearAttention.apply(time_decay, time_first, key, value, state, return_state) + + +# Copied from transformers.models.rwkv.modeling_rwkv.RwkvSelfAttention with RWKV->MAMBA,Rwkv->Mamba,rwkv->mamba +class MambaSelfAttention(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.config = config + kernel_loaded = mamba_cuda_kernel is not None and mamba_cuda_kernel.max_seq_length == config.context_length + if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded: + try: + load_wkv_cuda_kernel(config.context_length) + except Exception: + logger.info("Could not load the custom CUDA kernel for MAMBA attention.") + self.layer_id = layer_id + hidden_size = config.hidden_size + attention_hidden_size = ( + config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size + ) + self.attention_hidden_size = attention_hidden_size + + self.time_decay = nn.Parameter(torch.empty(attention_hidden_size)) + self.time_first = nn.Parameter(torch.empty(attention_hidden_size)) + + self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size)) + self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size)) + self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size)) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False) + self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False) + self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False) + self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False) + + # TODO: maybe jit, otherwise move inside forward + def extract_key_value(self, hidden, state=None): + # Mix hidden with the previous timestep to produce key, value, receptance + if hidden.size(1) == 1 and state is not None: + shifted = state[1][:, :, self.layer_id] + else: + shifted = self.time_shift(hidden) + if state is not None: + shifted[:, 0] = state[1][:, :, self.layer_id] + key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key) + value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value) + receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance) + + key = self.key(key) + value = self.value(value) + receptance = torch.sigmoid(self.receptance(receptance)) + if state is not None: + state[1][:, :, self.layer_id] = hidden[:, -1] + return receptance, key, value, state + + def forward(self, hidden, state=None, use_cache=False): + receptance, key, value, state = self.extract_key_value(hidden, state=state) + layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None + mamba, layer_state = mamba_linear_attention( + self.time_decay, + self.time_first, + key, + value, + state=layer_state, + return_state=use_cache, + ) + + if layer_state is not None: + state[2][:, :, self.layer_id] = layer_state[0] + state[3][:, :, self.layer_id] = layer_state[1] + state[4][:, :, self.layer_id] = layer_state[2] + + return self.output(receptance * mamba), state + + +# Copied from transformers.models.rwkv.modeling_rwkv.RwkvFeedForward with Rwkv->Mamba +class MambaFeedForward(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.config = config + self.layer_id = layer_id + hidden_size = config.hidden_size + intermediate_size = ( + config.intermediate_size if config.intermediate_size is not None else 4 * config.hidden_size + ) + + self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) + self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size)) + self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size)) + + self.key = nn.Linear(hidden_size, intermediate_size, bias=False) + self.receptance = nn.Linear(hidden_size, hidden_size, bias=False) + self.value = nn.Linear(intermediate_size, hidden_size, bias=False) + + def forward(self, hidden, state=None): + if hidden.size(1) == 1 and state is not None: + shifted = state[0][:, :, self.layer_id] + else: + shifted = self.time_shift(hidden) + if state is not None: + shifted[:, 0] = state[0][:, :, self.layer_id] + key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key) + receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance) + + key = torch.square(torch.relu(self.key(key))) + value = self.value(key) + receptance = torch.sigmoid(self.receptance(receptance)) + + if state is not None: + state[0][:, :, self.layer_id] = hidden[:, -1] + + return receptance * value, state + + +# Copied from transformers.models.rwkv.modeling_rwkv.RwkvBlock with Rwkv->Mamba +class MambaBlock(nn.Module): + def __init__(self, config, layer_id): + super().__init__() + self.config = config + self.layer_id = layer_id + + if layer_id == 0: + self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + self.attention = MambaSelfAttention(config, layer_id) + self.feed_forward = MambaFeedForward(config, layer_id) + + def forward(self, hidden, state=None, use_cache=False, output_attentions=False): + if self.layer_id == 0: + hidden = self.pre_ln(hidden) + + attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache) + hidden = hidden + attention + + feed_forward, state = self.feed_forward(self.ln2(hidden), state=state) + hidden = hidden + feed_forward + + outputs = (hidden, state) + if output_attentions: + outputs += (attention,) + else: + outputs += (None,) + + return outputs + + +# Copied from transformers.models.rwkv.modeling_rwkv.RwkvPreTrainedModel with Rwkv->Mamba,rwkv->mamba +class MambaPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = MambaConfig + base_model_prefix = "mamba" + _no_split_modules = ["MambaBlock"] + _keep_in_fp32_modules = ["time_decay", "time_first"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, MambaSelfAttention): + layer_id = module.layer_id + num_hidden_layers = module.config.num_hidden_layers + hidden_size = module.config.hidden_size + attention_hidden_size = module.attention_hidden_size + + ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0 + + time_weight = torch.tensor( + [i / hidden_size for i in range(hidden_size)], + dtype=module.time_mix_key.dtype, + device=module.time_mix_key.device, + ) + time_weight = time_weight[None, None, :] + + decay_speed = [ + -5 + 8 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1) + for h in range(attention_hidden_size) + ] + decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device) + zigzag = ( + torch.tensor( + [(i + 1) % 3 - 1 for i in range(attention_hidden_size)], + dtype=module.time_first.dtype, + device=module.time_first.device, + ) + * 0.5 + ) + + with torch.no_grad(): + module.time_decay.data = decay_speed + module.time_first.data = torch.ones_like(module.time_first * math.log(0.3) + zigzag) + + module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) + module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 + module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0) + elif isinstance(module, MambaFeedForward): + layer_id = module.layer_id + num_hidden_layers = module.config.num_hidden_layers + hidden_size = module.config.hidden_size + + ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0 + + time_weight = torch.tensor( + [i / hidden_size for i in range(hidden_size)], + dtype=module.time_mix_key.dtype, + device=module.time_mix_key.device, + ) + time_weight = time_weight[None, None, :] + + with torch.no_grad(): + module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) + module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0) + + +@dataclass +# Copied from transformers.models.rwkv.modeling_rwkv.RwkvOutput with RWKV->MAMBA,Rwkv->Mamba +class MambaOutput(ModelOutput): + """ + Class for the MAMBA model outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + state: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.rwkv.modeling_rwkv.RwkvCausalLMOutput with Rwkv->Mamba +class MambaCausalLMOutput(ModelOutput): + """ + Base class for causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`): + The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to + avoid providing the old `input_ids`. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + state: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +MAMBA_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MambaConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +MAMBA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + This is currently not used by `MambaModel`, but will be supported in the future. + + [What are attention masks?](../glossary#attention-mask) + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*): + If passed along, the model uses the previous state in all the blocks (which will give the output for the + `input_ids` provided as if the model add `state_input_ids + input_ids` as context). + use_cache (`bool`, *optional*): + If set to `True`, the last state is returned and can be used to quickly generate the next logits. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare MAMBA Model transformer outputting raw hidden-states without any specific head on top.", + MAMBA_START_DOCSTRING, +) +# Copied from transformers.models.rwkv.modeling_rwkv.RwkvModel with RWKV->MAMBA,Rwkv->Mamba +class MambaModel(MambaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) + self.blocks = nn.ModuleList([MambaBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)]) + self.ln_out = nn.LayerNorm(config.hidden_size) + + self.layers_are_rescaled = False + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embeddings + + def set_input_embeddings(self, new_embeddings): + self.embeddings = new_embeddings + + @add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MambaOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + state: Optional[List[torch.FloatTensor]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MambaOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if self.training == self.layers_are_rescaled: + self._rescale_layers() + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is None and inputs_embeds is None: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.embeddings(input_ids) + + if use_cache and state is None: + shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers) + state = [ + torch.zeros( + *shape, dtype=inputs_embeds.dtype if i <= 1 else torch.float32, device=inputs_embeds.device + ) + for i in range(5) + ] + state[4] -= 1e30 + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + hidden_states = inputs_embeds + + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + for idx, block in enumerate(self.blocks): + if self.gradient_checkpointing and self.training: + hidden_states, state, attentions = self._gradient_checkpointing_func( + block.__call__, hidden_states, state, use_cache, output_attentions + ) + else: + hidden_states, state, attentions = block( + hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions + ) + + if ( + self.layers_are_rescaled + and self.config.rescale_every > 0 + and (idx + 1) % self.config.rescale_every == 0 + ): + hidden_states = hidden_states / 2 + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if output_attentions: + all_self_attentions = all_self_attentions + (attentions,) + + hidden_states = self.ln_out(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(x for x in [hidden_states, state, all_hidden_states, all_self_attentions] if x is not None) + + return MambaOutput( + last_hidden_state=hidden_states, + state=state, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def _rescale_layers(self): + # Layers should be rescaled for inference only. + if self.layers_are_rescaled == (not self.training): + return + if self.config.rescale_every > 0: + with torch.no_grad(): + for block_id, block in enumerate(self.blocks): + if self.training: + block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every)) + block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every)) + else: + # Deal with quantization statistics + if hasattr(block.attention.output.weight, "SCB"): + block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every)) + block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every)) + elif hasattr(block.attention.output.weight, "quant_state"): + self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id) + self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id) + else: + block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every)) + block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every)) + + self.layers_are_rescaled = not self.training + + def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id): + r""" + Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will + be quantized again. + """ + if not is_bitsandbytes_available(): + raise ImportError("Please install bitsandbytes to use this method.") + import bitsandbytes as bnb + + dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state) + + dequant_weights.div_(2 ** int(block_id // self.config.rescale_every)) + + # re-quantize the model: + # we need to put it first on CPU then back to the device + # this will create an overhead :/ + # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid + # bugs with bnb + quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device) + setattr(target_layer, "weight", quant_weight) + + +@add_start_docstrings( + """ + The MAMBA Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + MAMBA_START_DOCSTRING, +) +# Copied from transformers.models.rwkv.modeling_rwkv.RwkvForCausalLM with RWKV->MAMBA,Rwkv->Mamba,rwkv->mamba +class MambaForCausalLM(MambaPreTrainedModel): + _tied_weights_keys = ["head.weight"] + + def __init__(self, config): + super().__init__(config) + self.mamba = MambaModel(config) + self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.head + + def set_output_embeddings(self, new_embeddings): + self.head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs): + # only last token for inputs_ids if the state is passed along. + if state is not None: + input_ids = input_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and state is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs["state"] = state + return model_inputs + + @add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=MambaCausalLMOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, # noqa + inputs_embeds: Optional[torch.FloatTensor] = None, + state: Optional[List[torch.FloatTensor]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MambaCausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + mamba_outputs = self.mamba( + input_ids, + inputs_embeds=inputs_embeds, + state=state, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = mamba_outputs[0] + + logits = self.head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (logits,) + mamba_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return MambaCausalLMOutput( + loss=loss, + logits=logits, + state=mamba_outputs.state, + hidden_states=mamba_outputs.hidden_states, + attentions=mamba_outputs.attentions, + ) diff --git a/tests/models/mamba/__init__.py b/tests/models/mamba/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py new file mode 100644 index 00000000000000..7c361bb252113e --- /dev/null +++ b/tests/models/mamba/test_modeling_mamba.py @@ -0,0 +1,452 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest +from unittest.util import safe_repr + +from transformers import AutoTokenizer, MambaConfig, is_torch_available +from transformers.testing_utils import require_torch, slow, torch_device + +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask +from ...test_pipeline_mixin import PipelineTesterMixin + + +if is_torch_available(): + import torch + + from transformers import ( + MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST, + MambaForCausalLM, + MambaModel, + ) + from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0 +else: + is_torch_greater_or_equal_than_2_0 = False + + +class MambaModelTester: + def __init__( + self, + parent, + batch_size=14, + seq_length=7, + is_training=True, + use_token_type_ids=False, + use_input_mask=True, + use_labels=True, + use_mc_token_ids=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=2, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + num_labels=3, + num_choices=4, + scope=None, + ): + self.parent = parent + self.batch_size = batch_size + self.seq_length = seq_length + self.is_training = is_training + self.use_token_type_ids = use_token_type_ids + self.use_input_mask = use_input_mask + self.use_labels = use_labels + self.use_mc_token_ids = use_mc_token_ids + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.type_sequence_label_size = type_sequence_label_size + self.num_labels = num_labels + self.num_choices = num_choices + self.scope = scope + self.bos_token_id = vocab_size - 1 + self.eos_token_id = vocab_size - 1 + self.pad_token_id = vocab_size - 1 + + def get_large_model_config(self): + return MambaConfig.from_pretrained("sgugger/mamba-4-pile-7b") + + def prepare_config_and_inputs( + self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False + ): + input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) + + input_mask = None + if self.use_input_mask: + input_mask = random_attention_mask([self.batch_size, self.seq_length]) + + token_type_ids = None + if self.use_token_type_ids: + token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) + + mc_token_ids = None + if self.use_mc_token_ids: + mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length) + + sequence_labels = None + token_labels = None + choice_labels = None + if self.use_labels: + sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size) + token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = ids_tensor([self.batch_size], self.num_choices) + + config = self.get_config( + gradient_checkpointing=gradient_checkpointing, + scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, + reorder_and_upcast_attn=reorder_and_upcast_attn, + ) + + return ( + config, + input_ids, + input_mask, + None, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) + + def get_config( + self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False + ): + return MambaConfig( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + num_hidden_layers=self.num_hidden_layers, + intermediate_size=self.intermediate_size, + activation_function=self.hidden_act, + resid_pdrop=self.hidden_dropout_prob, + attn_pdrop=self.attention_probs_dropout_prob, + n_positions=self.max_position_embeddings, + type_vocab_size=self.type_vocab_size, + use_cache=True, + bos_token_id=self.bos_token_id, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id, + gradient_checkpointing=gradient_checkpointing, + scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, + reorder_and_upcast_attn=reorder_and_upcast_attn, + ) + + def get_pipeline_config(self): + config = self.get_config() + config.vocab_size = 300 + return config + + def prepare_config_and_inputs_for_decoder(self): + ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() + + encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) + encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) + + return ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + sequence_labels, + token_labels, + choice_labels, + encoder_hidden_states, + encoder_attention_mask, + ) + + def create_and_check_mamba_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + config.output_hidden_states = True + model = MambaModel(config=config) + model.to(torch_device) + model.eval() + + result = model(input_ids) + + self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) + self.parent.assertEqual(len(result.hidden_states), config.num_hidden_layers + 1) + + def create_and_check_causl_lm(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = MambaForCausalLM(config) + model.to(torch_device) + model.eval() + + result = model(input_ids, labels=input_ids) + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + + def create_and_check_state_equivalency(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + model = MambaModel(config=config) + model.to(torch_device) + model.eval() + + outputs = model(input_ids) + output_whole = outputs.last_hidden_state + + outputs = model(input_ids[:, :2]) + output_one = outputs.last_hidden_state + + # Using the state computed on the first inputs, we will get the same output + outputs = model(input_ids[:, 2:], state=outputs.state) + output_two = outputs.last_hidden_state + + self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5)) + + def create_and_check_forward_and_backwards( + self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False + ): + model = MambaForCausalLM(config) + model.to(torch_device) + if gradient_checkpointing: + model.gradient_checkpointing_enable() + + result = model(input_ids, labels=input_ids) + self.parent.assertEqual(result.loss.shape, ()) + self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) + result.loss.backward() + + def prepare_config_and_inputs_for_common(self): + config_and_inputs = self.prepare_config_and_inputs() + + ( + config, + input_ids, + input_mask, + head_mask, + token_type_ids, + mc_token_ids, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + + inputs_dict = {"input_ids": input_ids} + + return config, inputs_dict + + +@unittest.skipIf( + not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204" +) +@require_torch +class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): + all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else () + # all_generative_model_classes = (MambaForCausalLM,) if is_torch_available() else () + fx_compatible = False + test_missing_keys = False + test_model_parallel = False + test_pruning = False + test_head_masking = False # Mamba does not support head masking + + def setUp(self): + self.model_tester = MambaModelTester(self) + self.config_tester = ConfigTester( + self, config_class=MambaConfig, n_embd=37, common_properties=["hidden_size", "num_hidden_layers"] + ) + + def assertInterval(self, member, container, msg=None): + r""" + Simple utility function to check if a member is inside an interval. + """ + if isinstance(member, torch.Tensor): + max_value, min_value = member.max().item(), member.min().item() + elif isinstance(member, list) or isinstance(member, tuple): + max_value, min_value = max(member), min(member) + + if not isinstance(container, list): + raise TypeError("container should be a list or tuple") + elif len(container) != 2: + raise ValueError("container should have 2 elements") + + expected_min, expected_max = container + + is_inside_interval = (min_value >= expected_min) and (max_value <= expected_max) + + if not is_inside_interval: + standardMsg = "%s not found in %s" % (safe_repr(member), safe_repr(container)) + self.fail(self._formatMessage(msg, standardMsg)) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_mamba_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_mamba_model(*config_and_inputs) + + def test_mamba_lm_head_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_causl_lm(*config_and_inputs) + + def test_state_equivalency(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_state_equivalency(*config_and_inputs) + + def test_initialization(self): + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + model = model_class(config=config) + for name, param in model.named_parameters(): + if "time_decay" in name: + if param.requires_grad: + self.assertTrue(param.data.max().item() == 3.0) + self.assertTrue(param.data.min().item() == -5.0) + elif "time_first" in name: + if param.requires_grad: + # check if it's a ones like + self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5)) + elif any(x in name for x in ["time_mix_key", "time_mix_receptance"]): + if param.requires_grad: + self.assertInterval( + param.data, + [0.0, 1.0], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + elif "time_mix_value" in name: + if param.requires_grad: + self.assertInterval( + param.data, + [0.0, 1.3], + msg=f"Parameter {name} of model {model_class} seems not properly initialized", + ) + + def test_attention_outputs(self): + r""" + Overriding the test_attention_outputs test as the attention outputs of Mamba are different from other models + it has a shape `batch_size, seq_len, hidden_size`. + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + batch_size = inputs["input_ids"].shape[0] + with torch.no_grad(): + outputs = model(**inputs) + attentions = outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + batch_size = inputs["input_ids"].shape[0] + with torch.no_grad(): + outputs = model(**inputs) + attentions = outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [batch_size, seq_len, config.hidden_size], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + + inputs = self._prepare_for_class(inputs_dict, model_class) + batch_size = inputs["input_ids"].shape[0] + with torch.no_grad(): + outputs = model(**inputs) + + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.attentions + + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [batch_size, seq_len, config.hidden_size], + ) + + @slow + def test_model_from_pretrained(self): + for model_name in MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: + model = MambaModel.from_pretrained(model_name) + self.assertIsNotNone(model) + + +@unittest.skipIf( + not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204" +) +@slow +class MAMBAIntegrationTests(unittest.TestCase): + def setUp(self): + self.model_id = "state-spaces/mamba-2.8b" + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + + def test_simple_generate(self): + expected_output = "Hello my name is Jasmine and I am a newbie to the" + model = MambaForCausalLM.from_pretrained(self.model_id).to(torch_device) + + input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) + output = model.generate(input_ids, max_new_tokens=10) + output_sentence = self.tokenizer.decode(output[0].tolist()) + + self.assertEqual(output_sentence, expected_output) + + def test_simple_generate_bf16(self): + expected_output = "Hello my name is Jasmine and I am a newbie to the" + + input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) + model = MambaForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device) + + output = model.generate(input_ids, max_new_tokens=10) + output_sentence = self.tokenizer.decode(output[0].tolist()) + + self.assertEqual(output_sentence, expected_output) From 00d3a6c1f0f35f9a0e0a88530fd3c5fd1f3f2db0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 1 Feb 2024 08:09:39 +0900 Subject: [PATCH 002/116] start cleaning --- .../models/mamba/modeling_mamba.py | 213 ++++-------------- 1 file changed, 42 insertions(+), 171 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 9ba8c840b07d29..c3340b63537375 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -54,14 +54,14 @@ mamba_cuda_kernel = None -# Copied from transformers.models.rwkv.modeling_rwkv.load_wkv_cuda_kernel with RWKV->MAMBA,rwkv->mamba -def load_wkv_cuda_kernel(context_length): +# Copied from transformers.models.mamba.modeling_mamba.load_mamba_cuda_kernel with mamba->MAMBA,mamba->mamba +def load_mamba_cuda_kernel(context_length): from torch.utils.cpp_extension import load as load_kernel global mamba_cuda_kernel kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "mamba" - cuda_kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu", "wkv_cuda_bf16.cu"]] + cuda_kernel_files = [kernel_folder / f for f in ["mamba_op.cpp", "mamba_cuda.cu", "mamba_cuda_bf16.cu"]] # Only load the kernel if it's not been loaded yet or if we changed the context length if mamba_cuda_kernel is not None and mamba_cuda_kernel.max_seq_length == context_length: @@ -79,7 +79,7 @@ def load_wkv_cuda_kernel(context_length): f"-DTmax={context_length}", ] mamba_cuda_kernel = load_kernel( - name=f"wkv_{context_length}", + name=f"mamba_{context_length}", sources=cuda_kernel_files, verbose=(logging.get_verbosity() == logging.DEBUG), extra_cuda_cflags=flags, @@ -87,8 +87,7 @@ def load_wkv_cuda_kernel(context_length): mamba_cuda_kernel.max_seq_length = context_length -# Copied from transformers.models.rwkv.modeling_rwkv.RwkvLinearAttention with Rwkv->Mamba,rwkv->mamba -class MambaLinearAttention(torch.autograd.Function): +class MambaMixer(torch.autograd.Function): @staticmethod def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False): batch_size, seq_len, hidden_size = key.size() @@ -111,7 +110,7 @@ def forward(ctx, time_decay, time_first, key, value, state=None, return_state=Fa or key.device.type != "cuda" or value.device.type != "cuda" ): - raise ValueError("Calling the CUDA kernel for wkv attention requires all tensors to be on CUDA devices.") + raise ValueError("Calling the CUDA kernel for mamba attention requires all tensors to be on CUDA devices.") time_decay = -torch.exp(time_decay.float().contiguous()) if key.dtype == torch.float16: @@ -194,7 +193,6 @@ def backward(ctx, g_output, g_state=None): ) -# Copied from transformers.models.rwkv.modeling_rwkv.rwkv_linear_attention_cpu with rwkv->mamba def mamba_linear_attention_cpu(time_decay, time_first, key, value, state=None, return_state=False): # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed # within a torch.no_grad. @@ -217,7 +215,7 @@ def mamba_linear_attention_cpu(time_decay, time_first, key, value, state=None, r current_key = key[:, current_index].float() current_value = value[:, current_index] - # wkv computation at time t + # mamba computation at time t max_for_output = torch.maximum(max_state, current_key + time_first) e1 = torch.exp(max_state - max_for_output) e2 = torch.exp(current_key + time_first - max_for_output) @@ -239,7 +237,6 @@ def mamba_linear_attention_cpu(time_decay, time_first, key, value, state=None, r return output, state -# Copied from transformers.models.rwkv.modeling_rwkv.rwkv_linear_attention with Rwkv->Mamba,rwkv->mamba def mamba_linear_attention(time_decay, time_first, key, value, state=None, return_state=False): no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value]) # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version @@ -251,131 +248,14 @@ def mamba_linear_attention(time_decay, time_first, key, value, state=None, retur return MambaLinearAttention.apply(time_decay, time_first, key, value, state, return_state) -# Copied from transformers.models.rwkv.modeling_rwkv.RwkvSelfAttention with RWKV->MAMBA,Rwkv->Mamba,rwkv->mamba -class MambaSelfAttention(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.config = config - kernel_loaded = mamba_cuda_kernel is not None and mamba_cuda_kernel.max_seq_length == config.context_length - if is_ninja_available() and is_torch_cuda_available() and not kernel_loaded: - try: - load_wkv_cuda_kernel(config.context_length) - except Exception: - logger.info("Could not load the custom CUDA kernel for MAMBA attention.") - self.layer_id = layer_id - hidden_size = config.hidden_size - attention_hidden_size = ( - config.attention_hidden_size if config.attention_hidden_size is not None else hidden_size - ) - self.attention_hidden_size = attention_hidden_size - - self.time_decay = nn.Parameter(torch.empty(attention_hidden_size)) - self.time_first = nn.Parameter(torch.empty(attention_hidden_size)) - - self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size)) - self.time_mix_value = nn.Parameter(torch.empty(1, 1, hidden_size)) - self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size)) - - self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - self.key = nn.Linear(hidden_size, attention_hidden_size, bias=False) - self.value = nn.Linear(hidden_size, attention_hidden_size, bias=False) - self.receptance = nn.Linear(hidden_size, attention_hidden_size, bias=False) - self.output = nn.Linear(attention_hidden_size, hidden_size, bias=False) - - # TODO: maybe jit, otherwise move inside forward - def extract_key_value(self, hidden, state=None): - # Mix hidden with the previous timestep to produce key, value, receptance - if hidden.size(1) == 1 and state is not None: - shifted = state[1][:, :, self.layer_id] - else: - shifted = self.time_shift(hidden) - if state is not None: - shifted[:, 0] = state[1][:, :, self.layer_id] - key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key) - value = hidden * self.time_mix_value + shifted * (1 - self.time_mix_value) - receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance) - - key = self.key(key) - value = self.value(value) - receptance = torch.sigmoid(self.receptance(receptance)) - if state is not None: - state[1][:, :, self.layer_id] = hidden[:, -1] - return receptance, key, value, state - - def forward(self, hidden, state=None, use_cache=False): - receptance, key, value, state = self.extract_key_value(hidden, state=state) - layer_state = tuple(s[:, :, self.layer_id] for s in state[2:]) if state is not None else None - mamba, layer_state = mamba_linear_attention( - self.time_decay, - self.time_first, - key, - value, - state=layer_state, - return_state=use_cache, - ) - - if layer_state is not None: - state[2][:, :, self.layer_id] = layer_state[0] - state[3][:, :, self.layer_id] = layer_state[1] - state[4][:, :, self.layer_id] = layer_state[2] - - return self.output(receptance * mamba), state - - -# Copied from transformers.models.rwkv.modeling_rwkv.RwkvFeedForward with Rwkv->Mamba -class MambaFeedForward(nn.Module): - def __init__(self, config, layer_id=0): - super().__init__() - self.config = config - self.layer_id = layer_id - hidden_size = config.hidden_size - intermediate_size = ( - config.intermediate_size if config.intermediate_size is not None else 4 * config.hidden_size - ) - - self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) - self.time_mix_key = nn.Parameter(torch.empty(1, 1, hidden_size)) - self.time_mix_receptance = nn.Parameter(torch.empty(1, 1, hidden_size)) - - self.key = nn.Linear(hidden_size, intermediate_size, bias=False) - self.receptance = nn.Linear(hidden_size, hidden_size, bias=False) - self.value = nn.Linear(intermediate_size, hidden_size, bias=False) - - def forward(self, hidden, state=None): - if hidden.size(1) == 1 and state is not None: - shifted = state[0][:, :, self.layer_id] - else: - shifted = self.time_shift(hidden) - if state is not None: - shifted[:, 0] = state[0][:, :, self.layer_id] - key = hidden * self.time_mix_key + shifted * (1 - self.time_mix_key) - receptance = hidden * self.time_mix_receptance + shifted * (1 - self.time_mix_receptance) - - key = torch.square(torch.relu(self.key(key))) - value = self.value(key) - receptance = torch.sigmoid(self.receptance(receptance)) - - if state is not None: - state[0][:, :, self.layer_id] = hidden[:, -1] - - return receptance * value, state - - -# Copied from transformers.models.rwkv.modeling_rwkv.RwkvBlock with Rwkv->Mamba class MambaBlock(nn.Module): def __init__(self, config, layer_id): super().__init__() self.config = config self.layer_id = layer_id - if layer_id == 0: - self.pre_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - - self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - - self.attention = MambaSelfAttention(config, layer_id) - self.feed_forward = MambaFeedForward(config, layer_id) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.mixer = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward(self, hidden, state=None, use_cache=False, output_attentions=False): if self.layer_id == 0: @@ -396,7 +276,7 @@ def forward(self, hidden, state=None, use_cache=False, output_attentions=False): return outputs -# Copied from transformers.models.rwkv.modeling_rwkv.RwkvPreTrainedModel with Rwkv->Mamba,rwkv->mamba +# Copied from transformers.models.mamba.modeling_mamba.mambaPreTrainedModel with mamba->Mamba,mamba->mamba class MambaPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -411,7 +291,7 @@ class MambaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights.""" - if isinstance(module, MambaSelfAttention): + if isinstance(module, MambaMixer): layer_id = module.layer_id num_hidden_layers = module.config.num_hidden_layers hidden_size = module.config.hidden_size @@ -448,27 +328,32 @@ def _init_weights(self, module): module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0) - elif isinstance(module, MambaFeedForward): - layer_id = module.layer_id - num_hidden_layers = module.config.num_hidden_layers - hidden_size = module.config.hidden_size - - ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0 - - time_weight = torch.tensor( - [i / hidden_size for i in range(hidden_size)], - dtype=module.time_mix_key.dtype, - device=module.time_mix_key.device, - ) - time_weight = time_weight[None, None, :] - - with torch.no_grad(): - module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) - module.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0) + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.n_residuals_per_layer * self.config.n_layer) @dataclass -# Copied from transformers.models.rwkv.modeling_rwkv.RwkvOutput with RWKV->MAMBA,Rwkv->Mamba class MambaOutput(ModelOutput): """ Class for the MAMBA model outputs. @@ -499,7 +384,6 @@ class MambaOutput(ModelOutput): @dataclass -# Copied from transformers.models.rwkv.modeling_rwkv.RwkvCausalLMOutput with Rwkv->Mamba class MambaCausalLMOutput(ModelOutput): """ Base class for causal language model (or autoregressive) outputs. @@ -595,14 +479,13 @@ class MambaCausalLMOutput(ModelOutput): "The bare MAMBA Model transformer outputting raw hidden-states without any specific head on top.", MAMBA_START_DOCSTRING, ) -# Copied from transformers.models.rwkv.modeling_rwkv.RwkvModel with RWKV->MAMBA,Rwkv->Mamba class MambaModel(MambaPreTrainedModel): def __init__(self, config): super().__init__(config) self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - self.blocks = nn.ModuleList([MambaBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)]) - self.ln_out = nn.LayerNorm(config.hidden_size) + self.layers = nn.ModuleList([MambaBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)]) + self.norm_f = nn.LayerNorm(config.hidden_size) self.layers_are_rescaled = False @@ -654,13 +537,8 @@ def forward( if use_cache and state is None: shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers) - state = [ - torch.zeros( - *shape, dtype=inputs_embeds.dtype if i <= 1 else torch.float32, device=inputs_embeds.device - ) - for i in range(5) - ] - state[4] -= 1e30 + dtype = inputs_embeds.dtype if i <= 1 else torch.float32 + state = [torch.zeros(*shape, dtype=dtype, device=inputs_embeds.device)for i in range(5)] if self.gradient_checkpointing and self.training: if use_cache: @@ -673,23 +551,16 @@ def forward( all_self_attentions = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - for idx, block in enumerate(self.blocks): + for idx, layer in enumerate(self.layers): if self.gradient_checkpointing and self.training: hidden_states, state, attentions = self._gradient_checkpointing_func( - block.__call__, hidden_states, state, use_cache, output_attentions + layer.__call__, hidden_states, state, use_cache, output_attentions ) else: - hidden_states, state, attentions = block( + hidden_states, state, attentions = layer( hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions ) - if ( - self.layers_are_rescaled - and self.config.rescale_every > 0 - and (idx + 1) % self.config.rescale_every == 0 - ): - hidden_states = hidden_states / 2 - if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -764,7 +635,7 @@ def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id): """, MAMBA_START_DOCSTRING, ) -# Copied from transformers.models.rwkv.modeling_rwkv.RwkvForCausalLM with RWKV->MAMBA,Rwkv->Mamba,rwkv->mamba +# Copied from transformers.models.mamba.modeling_mamba.mambaForCausalLM with mamba->MAMBA,mamba->Mamba,mamba->mamba class MambaForCausalLM(MambaPreTrainedModel): _tied_weights_keys = ["head.weight"] From 921bb24acbe8a0f2e601f2600da70c1756815ebe Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 1 Feb 2024 10:18:02 +0900 Subject: [PATCH 003/116] small nits --- src/transformers/models/mamba/configuration_mamba.py | 4 ++-- src/transformers/models/mamba/modeling_mamba.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 2ad117f4bd78e5..4553b65a5efd9c 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -64,7 +64,7 @@ class MambaConfig(PretrainedConfig): rescale_every (`int`, *optional*, defaults to 6): At inference, the hidden states (and weights of the correponding output layers) are divided by 2 every `rescale_every` layer. If set to 0 or a negative number, no rescale is done. - tie_word_embeddings (`bool`, *optional*, defaults to `False`): + tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether or not to tie the word embeddings with the input token embeddings. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last state. @@ -100,7 +100,7 @@ def __init__( bos_token_id=0, eos_token_id=0, rescale_every=6, - tie_word_embeddings=False, + tie_word_embeddings=True, use_cache=True, **kwargs, ): diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index c3340b63537375..60b582cffb7997 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -642,16 +642,16 @@ class MambaForCausalLM(MambaPreTrainedModel): def __init__(self, config): super().__init__(config) self.mamba = MambaModel(config) - self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing self.post_init() def get_output_embeddings(self): - return self.head + return self.lm_head def set_output_embeddings(self, new_embeddings): - self.head = new_embeddings + self.lm_head = new_embeddings def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs): # only last token for inputs_ids if the state is passed along. From b3f216d242fde9c98dc302f9f8e8db79771cfba5 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 3 Feb 2024 18:26:10 +0900 Subject: [PATCH 004/116] small nits --- .../models/mamba/modeling_mamba.py | 173 +++++++++--------- 1 file changed, 89 insertions(+), 84 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 60b582cffb7997..1b861d33fd9f5a 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -1,6 +1,5 @@ # coding=utf-8 -# Copyright 2023 Bo Peng and HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright 2024 Tri Dao, Albert Gu and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,6 +24,7 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ...activations import ACT2FN from ...modeling_utils import PreTrainedModel from ...utils import ( ModelOutput, @@ -248,31 +248,87 @@ def mamba_linear_attention(time_decay, time_first, key, value, state=None, retur return MambaLinearAttention.apply(time_decay, time_first, key, value, state, return_state) +class MambaMixer(nn.Module): + + def __init__(self, config, layer_idx): + super().__init__() + self.d_model = config.d_model + self.d_state = config.d_state + self.d_conv = config.d_conv + self.expand = config.expand + self.d_inner = int(self.expand * self.d_model) + self.dt_rank = math.ceil(self.d_model / 16) if config.dt_rank == "auto" else config.dt_rank + self.use_fast_path = config.use_fast_path + self.layer_idx = layer_idx + + self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=config.bias) + + self.conv1d = nn.Conv1d( + in_channels=self.d_inner, + out_channels=self.d_inner, + bias=config.conv_bias, + kernel_size=config.d_conv, + groups=self.d_inner, + padding=config.d_conv - 1, + ) + + self.activation = config.activation + self.act = ACT2FN[config.activation] + + + self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False) + self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True) + + # S4D real initialization + what_is_this = torch.arange(1, self.d_state + 1, dtype=torch.float32) + A = torch.repeat(what_is_this,d=self.d_inner).contiguous() + A_log = torch.log(A) # Keep A_log in fp32 + self.A_log = nn.Parameter(A_log) + self.A_log._no_weight_decay = True + + # D "skip" parameter + self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32 + self.D._no_weight_decay = True + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=config.bias) + + def forward(self, hidden_states, inference_params): + """ + hidden_states: (B, L, D) + Returns: same shape as hidden_states + """ + batch, seqlen, dim = hidden_states.shape + + conv_state, ssm_state = None, None + if inference_params is not None: + conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) + if inference_params.seqlen_offset > 0: + # The states are updated inplace + out, _, _ = self.step(hidden_states, conv_state, ssm_state) + return out + + + + return hidden_states + + + class MambaBlock(nn.Module): def __init__(self, config, layer_id): super().__init__() self.config = config self.layer_id = layer_id - + self.residual_in_fp32 = config.residual_in_fp32 self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.mixer = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - - def forward(self, hidden, state=None, use_cache=False, output_attentions=False): - if self.layer_id == 0: - hidden = self.pre_ln(hidden) - - attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache) - hidden = hidden + attention - - feed_forward, state = self.feed_forward(self.ln2(hidden), state=state) - hidden = hidden + feed_forward - - outputs = (hidden, state) - if output_attentions: - outputs += (attention,) - else: - outputs += (None,) - + self.mixer = MambaMixer(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward(self, hidden_states, residual=None, inference_params=None): + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + hidden_states = self.mixer(hidden_states, inference_params=inference_params) + outputs = (hidden_states, residual) return outputs @@ -352,6 +408,8 @@ def _init_weights(self, module): with torch.no_grad(): p /= math.sqrt(self.config.n_residuals_per_layer * self.config.n_layer) + def _setup_cache(self, batch_size, max_seqlen, dtype): + raise NotImplementedError @dataclass class MambaOutput(ModelOutput): @@ -485,10 +543,9 @@ def __init__(self, config): self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([MambaBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)]) - self.norm_f = nn.LayerNorm(config.hidden_size) + self.norm_f = nn.LayerNorm(config.hidden_size) # ir use ALL_LAYER_NORM[config.hidden_states] self.layers_are_rescaled = False - self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -524,8 +581,6 @@ def forward( use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.training == self.layers_are_rescaled: - self._rescale_layers() if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") @@ -534,11 +589,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embeddings(input_ids) - - if use_cache and state is None: + # TODO better to call _set_cache + if use_cache and cache is None: shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers) dtype = inputs_embeds.dtype if i <= 1 else torch.float32 - state = [torch.zeros(*shape, dtype=dtype, device=inputs_embeds.device)for i in range(5)] + cache = [torch.zeros(*shape, dtype=dtype, device=inputs_embeds.device)for i in range(5)] if self.gradient_checkpointing and self.training: if use_cache: @@ -553,11 +608,11 @@ def forward( all_hidden_states = () if output_hidden_states else None for idx, layer in enumerate(self.layers): if self.gradient_checkpointing and self.training: - hidden_states, state, attentions = self._gradient_checkpointing_func( + hidden_states, cache, partial_states = self._gradient_checkpointing_func( layer.__call__, hidden_states, state, use_cache, output_attentions ) else: - hidden_states, state, attentions = layer( + hidden_states, cache, partial_states = layer( hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions ) @@ -565,7 +620,7 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if output_attentions: - all_self_attentions = all_self_attentions + (attentions,) + all_self_attentions = all_self_attentions + (partial_states,) hidden_states = self.ln_out(hidden_states) @@ -577,56 +632,11 @@ def forward( return MambaOutput( last_hidden_state=hidden_states, - state=state, + state=cache, hidden_states=all_hidden_states, attentions=all_self_attentions, ) - def _rescale_layers(self): - # Layers should be rescaled for inference only. - if self.layers_are_rescaled == (not self.training): - return - if self.config.rescale_every > 0: - with torch.no_grad(): - for block_id, block in enumerate(self.blocks): - if self.training: - block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every)) - block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every)) - else: - # Deal with quantization statistics - if hasattr(block.attention.output.weight, "SCB"): - block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every)) - block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every)) - elif hasattr(block.attention.output.weight, "quant_state"): - self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id) - self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id) - else: - block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every)) - block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every)) - - self.layers_are_rescaled = not self.training - - def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id): - r""" - Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will - be quantized again. - """ - if not is_bitsandbytes_available(): - raise ImportError("Please install bitsandbytes to use this method.") - import bitsandbytes as bnb - - dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state) - - dequant_weights.div_(2 ** int(block_id // self.config.rescale_every)) - - # re-quantize the model: - # we need to put it first on CPU then back to the device - # this will create an overhead :/ - # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid - # bugs with bnb - quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device) - setattr(target_layer, "weight", quant_weight) - @add_start_docstrings( """ @@ -641,7 +651,7 @@ class MambaForCausalLM(MambaPreTrainedModel): def __init__(self, config): super().__init__(config) - self.mamba = MambaModel(config) + self.backbone = MambaModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize weights and apply final processing @@ -676,11 +686,8 @@ def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=Non def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.LongTensor] = None, # noqa inputs_embeds: Optional[torch.FloatTensor] = None, - state: Optional[List[torch.FloatTensor]] = None, labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -696,8 +703,6 @@ def forward( mamba_outputs = self.mamba( input_ids, inputs_embeds=inputs_embeds, - state=state, - use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -724,7 +729,7 @@ def forward( return MambaCausalLMOutput( loss=loss, logits=logits, - state=mamba_outputs.state, + cache=mamba_outputs.cache, hidden_states=mamba_outputs.hidden_states, attentions=mamba_outputs.attentions, ) From 7235b57f46293ae9d91fe4eed435ec11e8e1d80e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 3 Feb 2024 19:43:04 +0900 Subject: [PATCH 005/116] current updates --- .../models/mamba/configuration_mamba.py | 32 ++++----- .../models/mamba/modeling_mamba.py | 70 ++++++++++++++----- 2 files changed, 67 insertions(+), 35 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 4553b65a5efd9c..8c83cbe02ed503 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """ MAMBA configuration""" - +import math from ...configuration_utils import PretrainedConfig from ...utils import logging @@ -90,32 +90,32 @@ class MambaConfig(PretrainedConfig): def __init__( self, - vocab_size=50277, - context_length=1024, - hidden_size=4096, + vocab_size=50280, + hidden_size=768, + state_size=16, num_hidden_layers=32, - attention_hidden_size=None, - intermediate_size=None, layer_norm_epsilon=1e-5, - bos_token_id=0, - eos_token_id=0, - rescale_every=6, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + expand=2, + dt_rank="auto", tie_word_embeddings=True, - use_cache=True, **kwargs, ): self.vocab_size = vocab_size - self.context_length = context_length self.hidden_size = hidden_size + self.state_size = state_size self.num_hidden_layers = num_hidden_layers - self.attention_hidden_size = attention_hidden_size if attention_hidden_size is not None else hidden_size - self.intermediate_size = intermediate_size if intermediate_size is not None else 4 * hidden_size self.layer_norm_epsilon = layer_norm_epsilon - self.rescale_every = rescale_every - self.use_cache = use_cache - + self.d_inner = hidden_size * 2 + self.d_conv = 4 + self.state_size = state_size + self.expand = expand + self.time_step_rank = math.ceil(self.hidden_size / 16) if dt_rank == "auto" else dt_rank self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id + self.pad_token_id = pad_token_id super().__init__( tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 1b861d33fd9f5a..6511e30fce3a26 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -257,7 +257,7 @@ def __init__(self, config, layer_idx): self.d_conv = config.d_conv self.expand = config.expand self.d_inner = int(self.expand * self.d_model) - self.dt_rank = math.ceil(self.d_model / 16) if config.dt_rank == "auto" else config.dt_rank + self.time_step_rank = math.ceil(self.d_model / 16) if config.time_step_rank == "auto" else config.time_step_rank self.use_fast_path = config.use_fast_path self.layer_idx = layer_idx @@ -275,11 +275,12 @@ def __init__(self, config, layer_idx): self.activation = config.activation self.act = ACT2FN[config.activation] - - self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False) - self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True) - - # S4D real initialization + # selective projection used to make dt, B and C input dependant + self.x_proj = nn.Linear(self.d_inner, self.time_step_rank + self.d_state * 2, bias=False) + self.time_step_proj = nn.Linear(self.time_step_rank, self.d_inner, bias=True) + # S4D real initialization. These are not discretized! + # THe core is to load them, compute the discrete states, then write the updates state. + # Keeps the memory bounded what_is_this = torch.arange(1, self.d_state + 1, dtype=torch.float32) A = torch.repeat(what_is_this,d=self.d_inner).contiguous() A_log = torch.log(A) # Keep A_log in fp32 @@ -291,24 +292,55 @@ def __init__(self, config, layer_idx): self.D._no_weight_decay = True self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=config.bias) - def forward(self, hidden_states, inference_params): + def forward(self, hidden_states: torch.Tensor, inference_params=None): """ hidden_states: (B, L, D) Returns: same shape as hidden_states """ - batch, seqlen, dim = hidden_states.shape - - conv_state, ssm_state = None, None - if inference_params is not None: - conv_state, ssm_state = self._get_states_from_cache(inference_params, batch) - if inference_params.seqlen_offset > 0: - # The states are updated inplace - out, _, _ = self.step(hidden_states, conv_state, ssm_state) - return out - - + _, seqlen, _ = hidden_states.shape + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + + projected_states = self.in_proj(hidden_states).transpose(1,2) + x, z = projected_states.chunk(2, dim=1) - return hidden_states + if inference_params is not None and inference_params.seq_offset > 0: + x = causal_conv1d_update( + x, + conv_state, + self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), + self.conv1d.bias, + self.activation, + ) + else: + conv_state = F.pad(x, (self.d_conv - seqlen, 0)) + x = causal_conv1d_fn( + x=x, + weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), + bias=self.conv1d.bias, + activation=self.activation, + ) + + # We're careful here about the layout, to avoid extra transposes. + # We want dt to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. + x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) + dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = self.dt_proj.weight @ dt.t() + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + if inference_params is not None and inference_params.seq_offset > 0: + y, _ = selective_scan_update( + ssm_state, x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True + ) + else: + y, last_state = selective_scan_fn( + x, dt, self.negA, B, C, self.D.float(),z=z,delta_bias=self.dt_proj.bias.float(),delta_softplus=True,return_last_state=True, + ) + y = rearrange(y, "b d l -> b l d") + attn_outputs = self.out_proj(y) + return attn_outputs, conv_state, last_state + From 7a407a764b6c33e6f033b4e1654f24ad4f48821d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 5 Feb 2024 15:22:53 +0900 Subject: [PATCH 006/116] add kernels --- .../mamba/selective_scan/reverse_scan.cuh | 401 +++++++++++++ .../mamba/selective_scan/selective_scan.cpp | 497 ++++++++++++++++ .../mamba/selective_scan/selective_scan.h | 101 ++++ .../selective_scan_bwd_bf16_complex.cu | 9 + .../selective_scan_bwd_bf16_real.cu | 9 + .../selective_scan_bwd_fp16_complex.cu | 9 + .../selective_scan_bwd_fp16_real.cu | 9 + .../selective_scan_bwd_fp32_complex.cu | 9 + .../selective_scan_bwd_fp32_real.cu | 9 + .../selective_scan_bwd_kernel.cuh | 531 ++++++++++++++++++ .../selective_scan/selective_scan_common.h | 221 ++++++++ .../selective_scan/selective_scan_fwd_bf16.cu | 10 + .../selective_scan/selective_scan_fwd_fp16.cu | 10 + .../selective_scan/selective_scan_fwd_fp32.cu | 10 + .../selective_scan_fwd_kernel.cuh | 345 ++++++++++++ .../mamba/selective_scan/static_switch.h | 25 + .../selective_scan/uninitialized_copy.cuh | 69 +++ 17 files changed, 2274 insertions(+) create mode 100644 src/transformers/kernels/mamba/selective_scan/reverse_scan.cuh create mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan.cpp create mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan.h create mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_complex.cu create mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_real.cu create mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_complex.cu create mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_real.cu create mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_complex.cu create mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_real.cu create mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_kernel.cuh create mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_common.h create mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_bf16.cu create mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp16.cu create mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp32.cu create mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_kernel.cuh create mode 100644 src/transformers/kernels/mamba/selective_scan/static_switch.h create mode 100644 src/transformers/kernels/mamba/selective_scan/uninitialized_copy.cuh diff --git a/src/transformers/kernels/mamba/selective_scan/reverse_scan.cuh b/src/transformers/kernels/mamba/selective_scan/reverse_scan.cuh new file mode 100644 index 00000000000000..d7e93174bb391d --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/reverse_scan.cuh @@ -0,0 +1,401 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include +#include +#include +// #include +#include "uninitialized_copy.cuh" + +/** + * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned. + */ +template < + int LENGTH, + typename T, + typename ReductionOp> +__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) { + static_assert(LENGTH > 0); + T retval = input[LENGTH - 1]; + #pragma unroll + for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); } + return retval; +} + +/** + * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. + */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadReverseScanInclusive( + const T (&input)[LENGTH], + T (&output)[LENGTH], + ScanOp scan_op, + const T postfix) +{ + T inclusive = postfix; + #pragma unroll + for (int i = LENGTH - 1; i >= 0; --i) { + inclusive = scan_op(inclusive, input[i]); + output[i] = inclusive; + } +} + +/** + * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. + */ +template < + int LENGTH, + typename T, + typename ScanOp> +__device__ __forceinline__ T ThreadReverseScanExclusive( + const T (&input)[LENGTH], + T (&output)[LENGTH], + ScanOp scan_op, + const T postfix) +{ + // Careful, output maybe be aliased to input + T exclusive = postfix; + T inclusive; + #pragma unroll + for (int i = LENGTH - 1; i >= 0; --i) { + inclusive = scan_op(exclusive, input[i]); + output[i] = exclusive; + exclusive = inclusive; + } + return inclusive; +} + + +/** + * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp. + * + * LOGICAL_WARP_THREADS must be a power-of-two + */ +template < + typename T, ///< Data type being scanned + int LOGICAL_WARP_THREADS ///< Number of threads per logical warp + > +struct WarpReverseScan { + //--------------------------------------------------------------------- + // Constants and type definitions + //--------------------------------------------------------------------- + + /// Whether the logical warp size and the PTX warp size coincide + static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0)); + /// The number of warp scan steps + static constexpr int STEPS = cub::Log2::VALUE; + static_assert(LOGICAL_WARP_THREADS == 1 << STEPS); + + + //--------------------------------------------------------------------- + // Thread fields + //--------------------------------------------------------------------- + + /// Lane index in logical warp + unsigned int lane_id; + + /// Logical warp index in 32-thread physical warp + unsigned int warp_id; + + /// 32-thread physical warp member mask of logical warp + unsigned int member_mask; + + //--------------------------------------------------------------------- + // Construction + //--------------------------------------------------------------------- + + /// Constructor + explicit __device__ __forceinline__ + WarpReverseScan() + : lane_id(cub::LaneId()) + , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS)) + , member_mask(cub::WarpMask(warp_id)) + { + if (!IS_ARCH_WARP) { + lane_id = lane_id % LOGICAL_WARP_THREADS; + } + } + + + /// Broadcast + __device__ __forceinline__ T Broadcast( + T input, ///< [in] The value to broadcast + int src_lane) ///< [in] Which warp lane is to do the broadcasting + { + return cub::ShuffleIndex(input, src_lane, member_mask); + } + + + /// Inclusive scan + template + __device__ __forceinline__ void InclusiveReverseScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + ScanOpT scan_op) ///< [in] Binary scan operator + { + inclusive_output = input; + #pragma unroll + for (int STEP = 0; STEP < STEPS; STEP++) { + int offset = 1 << STEP; + T temp = cub::ShuffleDown( + inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask + ); + // Perform scan op if from a valid peer + inclusive_output = static_cast(lane_id) >= LOGICAL_WARP_THREADS - offset + ? inclusive_output : scan_op(temp, inclusive_output); + } + } + + /// Exclusive scan + // Get exclusive from inclusive + template + __device__ __forceinline__ void ExclusiveReverseScan( + T input, ///< [in] Calling thread's input item. + T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. + ScanOpT scan_op, ///< [in] Binary scan operator + T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. + { + T inclusive_output; + InclusiveReverseScan(input, inclusive_output, scan_op); + warp_aggregate = cub::ShuffleIndex(inclusive_output, 0, member_mask); + // initial value unknown + exclusive_output = cub::ShuffleDown( + inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask + ); + } + + /** + * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last warp-lane is undefined. + */ + template + __device__ __forceinline__ void ReverseScan( + T input, ///< [in] Calling thread's input item. + T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item. + T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item. + ScanOpT scan_op) ///< [in] Binary scan operator + { + InclusiveReverseScan(input, inclusive_output, scan_op); + // initial value unknown + exclusive_output = cub::ShuffleDown( + inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask + ); + } + +}; + +/** + * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block. + */ +template < + typename T, ///< Data type being scanned + int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension + bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure + > +struct BlockReverseScan { + //--------------------------------------------------------------------- + // Types and constants + //--------------------------------------------------------------------- + + /// Constants + /// The thread block size in threads + static constexpr int BLOCK_THREADS = BLOCK_DIM_X; + + /// Layout type for padded thread block raking grid + using BlockRakingLayout = cub::BlockRakingLayout; + // The number of reduction elements is not a multiple of the number of raking threads for now + static_assert(BlockRakingLayout::UNGUARDED); + + /// Number of raking threads + static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS; + /// Number of raking elements per warp synchronous raking thread + static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH; + /// Cooperative work can be entirely warp synchronous + static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS)); + + /// WarpReverseScan utility type + using WarpReverseScan = WarpReverseScan; + + /// Shared memory storage layout type + struct _TempStorage { + typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid + }; + + + /// Alias wrapper allowing storage to be unioned + struct TempStorage : cub::Uninitialized<_TempStorage> {}; + + + //--------------------------------------------------------------------- + // Per-thread fields + //--------------------------------------------------------------------- + + // Thread fields + _TempStorage &temp_storage; + unsigned int linear_tid; + T cached_segment[SEGMENT_LENGTH]; + + + //--------------------------------------------------------------------- + // Utility methods + //--------------------------------------------------------------------- + + /// Performs upsweep raking reduction, returning the aggregate + template + __device__ __forceinline__ T Upsweep(ScanOp scan_op) { + T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + // Read data into registers + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } + T raking_partial = cached_segment[SEGMENT_LENGTH - 1]; + #pragma unroll + for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) { + raking_partial = scan_op(raking_partial, cached_segment[i]); + } + return raking_partial; + } + + + /// Performs exclusive downsweep raking scan + template + __device__ __forceinline__ void ExclusiveDownsweep( + ScanOp scan_op, + T raking_partial) + { + T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); + // Read data back into registers + if (!MEMOIZE) { + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } + } + ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial); + // Write data back to smem + #pragma unroll + for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; } + } + + + //--------------------------------------------------------------------- + // Constructors + //--------------------------------------------------------------------- + + /// Constructor + __device__ __forceinline__ BlockReverseScan( + TempStorage &temp_storage) + : + temp_storage(temp_storage.Alias()), + linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1)) + {} + + + /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + template < + typename ScanOp, + typename BlockPostfixCallbackOp> + __device__ __forceinline__ void ExclusiveReverseScan( + T input, ///< [in] Calling thread's input item + T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan operator + BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide postfix to be applied to all inputs. + { + if (WARP_SYNCHRONOUS) { + // Short-circuit directly to warp-synchronous scan + T block_aggregate; + WarpReverseScan warp_scan; + warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate); + // Obtain warp-wide postfix in lane0, then broadcast to other lanes + T block_postfix = block_postfix_callback_op(block_aggregate); + block_postfix = warp_scan.Broadcast(block_postfix, 0); + exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output); + } else { + // Place thread partial into shared memory raking grid + T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); + detail::uninitialized_copy(placement_ptr, input); + cub::CTA_SYNC(); + // Reduce parallelism down to just raking threads + if (linear_tid < RAKING_THREADS) { + WarpReverseScan warp_scan; + // Raking upsweep reduction across shared partials + T upsweep_partial = Upsweep(scan_op); + // Warp-synchronous scan + T exclusive_partial, block_aggregate; + warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate); + // Obtain block-wide postfix in lane0, then broadcast to other lanes + T block_postfix = block_postfix_callback_op(block_aggregate); + block_postfix = warp_scan.Broadcast(block_postfix, 0); + // Update postfix with warpscan exclusive partial + T downsweep_postfix = linear_tid == RAKING_THREADS - 1 + ? block_postfix : scan_op(block_postfix, exclusive_partial); + // Exclusive raking downsweep scan + ExclusiveDownsweep(scan_op, downsweep_postfix); + } + cub::CTA_SYNC(); + // Grab thread postfix from shared memory + exclusive_output = *placement_ptr; + + // // Compute warp scan in each warp. + // // The exclusive output from the last lane in each warp is invalid. + // T inclusive_output; + // WarpReverseScan warp_scan; + // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op); + + // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid. + // T block_aggregate; + // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate); + + // // Apply warp postfix to our lane's partial + // if (warp_id != 0) { + // exclusive_output = scan_op(warp_postfix, exclusive_output); + // if (lane_id == 0) { exclusive_output = warp_postfix; } + // } + + // // Use the first warp to determine the thread block postfix, returning the result in lane0 + // if (warp_id == 0) { + // T block_postfix = block_postfix_callback_op(block_aggregate); + // if (lane_id == 0) { + // // Share the postfix with all threads + // detail::uninitialized_copy(&temp_storage.block_postfix, + // block_postfix); + + // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0 + // } + // } + + // cub::CTA_SYNC(); + + // // Incorporate thread block postfix into outputs + // T block_postfix = temp_storage.block_postfix; + // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); } + } + } + + + /** + * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. + */ + template < + int ITEMS_PER_THREAD, + typename ScanOp, + typename BlockPostfixCallbackOp> + __device__ __forceinline__ void InclusiveReverseScan( + T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items + T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) + ScanOp scan_op, ///< [in] Binary scan functor + BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence. + { + // Reduce consecutive thread items in registers + T thread_postfix = ThreadReverseReduce(input, scan_op); + // Exclusive thread block-scan + ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op); + // Inclusive scan in registers with postfix as seed + ThreadReverseScanInclusive(input, output, scan_op, thread_postfix); + } + +}; \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan.cpp b/src/transformers/kernels/mamba/selective_scan/selective_scan.cpp new file mode 100644 index 00000000000000..cde867cd32d39b --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/selective_scan.cpp @@ -0,0 +1,497 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#include +#include +#include +#include + +#include "selective_scan.h" + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ + if (WTYPE == at::ScalarType::Half) { \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::BFloat16) { \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \ + if (WTYPE == at::ScalarType::Float) { \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else if (WTYPE == at::ScalarType::ComplexFloat) { \ + using weight_t = c10::complex; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ + } + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + const bool is_variable_B, + const bool is_variable_C, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor out, + const at::Tensor z, + const at::Tensor out_z, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + bool has_z, + bool delta_softplus) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.n_chunks = n_chunks; + params.dim_ngroups_ratio = dim / n_groups; + + params.delta_softplus = delta_softplus; + + params.is_variable_B = is_variable_B; + params.is_variable_C = is_variable_C; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D_ptr; + params.delta_bias_ptr = delta_bias_ptr; + params.out_ptr = out.data_ptr(); + params.x_ptr = x_ptr; + params.z_ptr = has_z ? z.data_ptr() : nullptr; + params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + // All stride are in elements, not bytes. + params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + if (!is_variable_B) { + params.B_d_stride = B.stride(0); + } else { + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + } + params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); + if (!is_variable_C) { + params.C_d_stride = C.stride(0); + } else { + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + } + params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + if (has_z) { + params.z_batch_stride = z.stride(0); + params.z_d_stride = z.stride(1); + params.out_z_batch_stride = out_z.stride(0); + params.out_z_d_stride = out_z.stride(1); + } + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); +} + +void set_ssm_params_bwd(SSMParamsBwd ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + const bool is_variable_B, + const bool is_variable_C, + // device pointers + const at::Tensor u, + const at::Tensor delta, + const at::Tensor A, + const at::Tensor B, + const at::Tensor C, + const at::Tensor z, + const at::Tensor out, + const at::Tensor out_z, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + const at::Tensor dout, + const at::Tensor du, + const at::Tensor ddelta, + const at::Tensor dA, + const at::Tensor dB, + const at::Tensor dC, + const at::Tensor dz, + void* dD_ptr, + void* ddelta_bias_ptr, + bool has_z, + bool delta_softplus, + bool recompute_out_z) { + // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z + set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, has_z ? out : dout, + has_z ? z : dout, + // If not recompute_out_z, pass dout instead of out_z. + // This won't be used by the bwd kernel + recompute_out_z ? out_z : dout, + D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus); + if (!recompute_out_z) { params.out_z_ptr = nullptr; } + + // Set the pointers and strides. + params.dout_ptr = dout.data_ptr(); + params.du_ptr = du.data_ptr(); + params.dA_ptr = dA.data_ptr(); + params.dB_ptr = dB.data_ptr(); + params.dC_ptr = dC.data_ptr(); + params.dD_ptr = dD_ptr; + params.ddelta_ptr = ddelta.data_ptr(); + params.ddelta_bias_ptr = ddelta_bias_ptr; + params.dz_ptr = has_z ? dz.data_ptr() : nullptr; + // All stride are in elements, not bytes. + params.dout_batch_stride = dout.stride(0); + params.dout_d_stride = dout.stride(1); + params.dA_d_stride = dA.stride(0); + params.dA_dstate_stride = dA.stride(1); + if (!is_variable_B) { + params.dB_d_stride = dB.stride(0); + } else { + params.dB_batch_stride = dB.stride(0); + params.dB_group_stride = dB.stride(1); + } + params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2); + if (!is_variable_C) { + params.dC_d_stride = dC.stride(0); + } else { + params.dC_batch_stride = dC.stride(0); + params.dC_group_stride = dC.stride(1); + } + params.dC_dstate_stride = !is_variable_C ? dC.stride(1) : dC.stride(2); + params.du_batch_stride = du.stride(0); + params.du_d_stride = du.stride(1); + params.ddelta_batch_stride = ddelta.stride(0); + params.ddelta_d_stride = ddelta.stride(1); + if (has_z) { + params.dz_batch_stride = dz.stride(0); + params.dz_d_stride = dz.stride(1); + } +} + +std::vector +selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + bool delta_softplus) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + const bool is_complex = weight_type == at::ScalarType::ComplexFloat; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = is_variable_B ? B.size(1) : 1; + + TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + if (!is_variable_B) { + CHECK_SHAPE(B, dim, dstate); + } else { + CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + } + if (!is_variable_C) { + CHECK_SHAPE(C, dim, dstate); + } else { + CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + } + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + at::Tensor z, out_z; + const bool has_z = z_.has_value(); + if (has_z) { + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + out_z = torch::empty_like(z); + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + // at::Tensor out = torch::empty_like(u); + // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout + at::Tensor out = torch::empty_like(delta); + at::Tensor x; + x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type)); + + SSMParamsBase params; + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, out, z, out_z, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x.data_ptr(), + has_z, + delta_softplus); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { + DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] { + selective_scan_fwd_cuda(params, stream); + }); + }); + std::vector result = {out, x}; + if (has_z) { result.push_back(out_z); } + return result; +} + +std::vector +selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, + const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + const at::Tensor &dout, + const c10::optional &x_, + const c10::optional &out_, + c10::optional &dz_, + bool delta_softplus, + bool recompute_out_z) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + const bool is_complex = weight_type == at::ScalarType::ComplexFloat; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + TORCH_CHECK(dout.scalar_type() == input_type); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + TORCH_CHECK(dout.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = is_variable_B ? B.size(1) : 1; + + TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + if (!is_variable_B) { + CHECK_SHAPE(B, dim, dstate); + } else { + CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + } + if (!is_variable_C) { + CHECK_SHAPE(C, dim, dstate); + } else { + CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + } + CHECK_SHAPE(dout, batch_size, dim, seqlen); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + at::Tensor z, out, dz, out_z; + const bool has_z = z_.has_value(); + if (has_z) { + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + + TORCH_CHECK(out_.has_value()); + out = out_.value(); + TORCH_CHECK(out.scalar_type() == input_type); + TORCH_CHECK(out.is_cuda()); + TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1); + CHECK_SHAPE(out, batch_size, dim, seqlen); + + if (dz_.has_value()) { + dz = dz_.value(); + TORCH_CHECK(dz.scalar_type() == input_type); + TORCH_CHECK(dz.is_cuda()); + TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1); + CHECK_SHAPE(dz, batch_size, dim, seqlen); + } else { + dz = torch::empty_like(z); + } + if (recompute_out_z) { + out_z = torch::empty_like(out); + } + } + + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); } + if (x_.has_value()) { + auto x = x_.value(); + TORCH_CHECK(x.scalar_type() == weight_type); + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x.is_contiguous()); + CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate); + } + + at::Tensor du = torch::empty_like(u); + at::Tensor ddelta = torch::empty_like(delta); + at::Tensor dA = torch::zeros_like(A); + at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32)); + at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32)); + at::Tensor dD; + if (D_.has_value()) { dD = torch::zeros_like(D_.value()); } + at::Tensor ddelta_bias; + if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); } + + SSMParamsBwd params; + set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, z, out, out_z, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x_.has_value() ? x_.value().data_ptr() : nullptr, + dout, du, ddelta, dA, dB, dC, dz, + D_.has_value() ? dD.data_ptr() : nullptr, + delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr, + has_z, delta_softplus, recompute_out_z); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] { + DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] { + selective_scan_bwd_cuda(params, stream); + }); + }); + std::vector result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias}; + if (has_z) { result.push_back(dz); } + if (recompute_out_z) { result.push_back(out_z); } + return result; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("fwd", &selective_scan_fwd, "Selective scan forward"); + m.def("bwd", &selective_scan_bwd, "Selective scan backward"); +} diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan.h b/src/transformers/kernels/mamba/selective_scan/selective_scan.h new file mode 100644 index 00000000000000..e2c7bcdbd5ddad --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/selective_scan.h @@ -0,0 +1,101 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMScanParamsBase { + using index_t = uint32_t; + + int batch, seqlen, n_chunks; + index_t a_batch_stride; + index_t b_batch_stride; + index_t out_batch_stride; + + // Common data pointers. + void *__restrict__ a_ptr; + void *__restrict__ b_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, dstate, n_groups, n_chunks; + int dim_ngroups_ratio; + bool is_variable_B; + bool is_variable_C; + + bool delta_softplus; + + index_t A_d_stride; + index_t A_dstate_stride; + index_t B_batch_stride; + index_t B_d_stride; + index_t B_dstate_stride; + index_t B_group_stride; + index_t C_batch_stride; + index_t C_d_stride; + index_t C_dstate_stride; + index_t C_group_stride; + index_t u_batch_stride; + index_t u_d_stride; + index_t delta_batch_stride; + index_t delta_d_stride; + index_t z_batch_stride; + index_t z_d_stride; + index_t out_batch_stride; + index_t out_d_stride; + index_t out_z_batch_stride; + index_t out_z_d_stride; + + // Common data pointers. + void *__restrict__ A_ptr; + void *__restrict__ B_ptr; + void *__restrict__ C_ptr; + void *__restrict__ D_ptr; + void *__restrict__ u_ptr; + void *__restrict__ delta_ptr; + void *__restrict__ delta_bias_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; + void *__restrict__ z_ptr; + void *__restrict__ out_z_ptr; +}; + +struct SSMParamsBwd: public SSMParamsBase { + index_t dout_batch_stride; + index_t dout_d_stride; + index_t dA_d_stride; + index_t dA_dstate_stride; + index_t dB_batch_stride; + index_t dB_group_stride; + index_t dB_d_stride; + index_t dB_dstate_stride; + index_t dC_batch_stride; + index_t dC_group_stride; + index_t dC_d_stride; + index_t dC_dstate_stride; + index_t du_batch_stride; + index_t du_d_stride; + index_t dz_batch_stride; + index_t dz_d_stride; + index_t ddelta_batch_stride; + index_t ddelta_d_stride; + + // Common data pointers. + void *__restrict__ dout_ptr; + void *__restrict__ dA_ptr; + void *__restrict__ dB_ptr; + void *__restrict__ dC_ptr; + void *__restrict__ dD_ptr; + void *__restrict__ du_ptr; + void *__restrict__ dz_ptr; + void *__restrict__ ddelta_ptr; + void *__restrict__ ddelta_bias_ptr; +}; diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_complex.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_complex.cu new file mode 100644 index 00000000000000..c55f0e858af4eb --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_complex.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_real.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_real.cu new file mode 100644 index 00000000000000..72adaf5cb13c64 --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_real.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_complex.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_complex.cu new file mode 100644 index 00000000000000..df126d7c8d5f9f --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_complex.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_real.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_real.cu new file mode 100644 index 00000000000000..3ff271b50eaff2 --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_real.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_complex.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_complex.cu new file mode 100644 index 00000000000000..5554902342785b --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_complex.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_real.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_real.cu new file mode 100644 index 00000000000000..a7ed642231da80 --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_real.cu @@ -0,0 +1,9 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_bwd_kernel.cuh" + +template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_kernel.cuh b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_kernel.cuh new file mode 100644 index 00000000000000..2ed101148a4b32 --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_kernel.cuh @@ -0,0 +1,531 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK +#include // For atomicAdd on complex + +#include +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "reverse_scan.cuh" +#include "static_switch.h" + +template __device__ __forceinline__ scalar_t conj(scalar_t x); +template<> __device__ __forceinline__ float conj(float x) { return x; } +template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } + +template +struct Selective_Scan_bwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kNItems = kNItems_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsComplex = std::is_same_v; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; + static constexpr bool kHasZ = kHasZ_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. + // For complex this would lead to massive register spilling, so we keep it at 2. + static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2; + using vec_t = typename BytesToType::Type; + using scan_t = std::conditional_t; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockReverseScanT = BlockReverseScan; + using BlockReduceT = cub::BlockReduce; + using BlockReduceFloatT = cub::BlockReduce; + using BlockReduceComplexT = cub::BlockReduce; + using BlockExchangeT = cub::BlockExchange; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage); + static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_bwd_kernel(SSMParamsBwd params) { + constexpr bool kIsComplex = Ktraits::kIsComplex; + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); + auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); + auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); + auto& smem_reduce_complex = *reinterpret_cast(&smem_reduce); + auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); + auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); + weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); + scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + 2 * MAX_DSTATE + kNThreads); + weight_t *smem_da = reinterpret_cast(smem_running_postfix + MAX_DSTATE); + weight_t *smem_dbc = reinterpret_cast(smem_da + MAX_DSTATE); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * params.delta_d_stride; + input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride + + dim_id * params.dout_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * params.dA_d_stride; + weight_t *dB = reinterpret_cast(params.dB_ptr) + + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride); + weight_t *dC = reinterpret_cast(params.dC_ptr) + + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride); + float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id; + float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast(params.D_ptr)[dim_id]; + float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id; + float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast(params.delta_bias_ptr)[dim_id]; + scan_t *x = params.x_ptr == nullptr + ? nullptr + : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; + float dD_val = 0; + float ddelta_bias_val = 0; + + constexpr int kChunkSize = kNThreads * kNItems; + u += (params.n_chunks - 1) * kChunkSize; + delta += (params.n_chunks - 1) * kChunkSize; + dout += (params.n_chunks - 1) * kChunkSize; + Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); + for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { + input_t u_vals[kNItems]; + input_t delta_vals_load[kNItems]; + input_t dout_vals_load[kNItems]; + __syncthreads(); + load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); + u -= kChunkSize; + __syncthreads(); + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + // Will reload delta at the same location if kDeltaSoftplus + if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } + __syncthreads(); + load_input(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + dout -= kChunkSize; + + float dout_vals[kNItems], delta_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dout_vals[i] = float(dout_vals_load[i]); + delta_vals[i] = float(delta_vals_load[i]) + delta_bias; + if constexpr (kDeltaSoftplus) { + delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i]; + } + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * params.z_d_stride + chunk * kChunkSize; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * params.out_d_stride + chunk * kChunkSize; + input_t *dz = reinterpret_cast(params.dz_ptr) + batch_id * params.dz_batch_stride + + dim_id * params.dz_d_stride + chunk * kChunkSize; + input_t z_vals[kNItems], out_vals[kNItems]; + __syncthreads(); + load_input(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + __syncthreads(); + load_input(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize); + float dz_vals[kNItems], z_silu_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val)); + z_silu_vals[i] = z_val * z_sigmoid_val; + dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val + * (1.0f + z_val * (1.0f - z_sigmoid_val)); + dout_vals[i] *= z_silu_vals[i]; + } + __syncthreads(); + store_output(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize); + if (params.out_z_ptr != nullptr) { // Recompute and store out_z + float out_z_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; } + // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { + // printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]); + // } + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * params.out_z_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize); + } + } + + float du_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); } + + float ddelta_vals[kNItems] = {0}; + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + const weight_t A_val = A[state_idx * params.A_dstate_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + weight_t A_scaled; + constexpr float kLog2e = M_LOG2E; + if constexpr (!kIsComplex) { + A_scaled = A_val * kLog2e; + } else { + A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_); + } + weight_t B_val, C_val; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (!kIsVariableB) { + B_val = B[state_idx * params.B_dstate_stride]; + } else { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + if constexpr (!kIsVariableC) { + C_val = C[state_idx * params.C_dstate_stride]; + } else { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + } + // const weight_t A_val = smem_a[state_idx]; + scan_t thread_data[kNItems], thread_reverse_data[kNItems]; + if constexpr (!kIsComplex) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); + thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp; + } + thread_reverse_data[i].y = dout_vals[i] * + (!kIsVariableC + ? (!kIsVariableB ? B_val * C_val : C_val) + : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i])); + } + __syncthreads(); + thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE]) + : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE]; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const float dx = thread_reverse_data[i].y; + const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i]; + du_vals[i] += ddelta_u * delta_vals[i]; + const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a; + dA_val += dx * delta_vals[i] * a; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += dout_vals[i] * thread_data[i].y; + } + } + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + if constexpr (kIsVariableB) { + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); + } + if constexpr (kIsVariableC) { + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); + } + const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; + weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; + weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float2 dA_dBC_val = make_float2(dA_val, dBC_val); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = dA_dBC_val.x; + if (threadIdx.x == 0) { + smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx]; + } + } else { + dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; + } + } else { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + // Pytorch's implementation of complex exp (which calls thrust) is very slow + complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled); + weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]); + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if (i == 0) { + smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; + } else { + thread_reverse_data[i - 1].x = delta_a_exp.real_; + thread_reverse_data[i - 1].y = -delta_a_exp.imag_; + } + complex_t dout_BC = 2 * dout_vals[i] + * conj(!kIsVariableC + ? (!kIsVariableB ? B_val * C_val : C_val) + : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i])); + thread_reverse_data[i].z = dout_BC.real_; + thread_reverse_data[i].w = dout_BC.imag_; + } + __syncthreads(); + complex_t delta_a_exp = threadIdx.x == kNThreads - 1 + ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE]) + : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE]; + thread_reverse_data[kNItems - 1].x = delta_a_exp.real_; + thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_; + // Initialize running total + scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + SSMScanPrefixCallbackOp postfix_op(running_postfix); + Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( + thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op + ); + if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } + weight_t dA_val = 0, dBC_val = 0; + weight_t dB_vals[kNItems], dC_vals[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + complex_t x = complex_t(thread_data[i].z, thread_data[i].w); + complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w); + float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_; + if constexpr (!kIsVariableB || !kIsVariableC) { + if constexpr (!kIsVariableB) { // dBC_val is dB_val + dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]); + } else { // dBC_val is dC_val + dBC_val += (2 * dout_vals[i]) * conj(x); + } + } + const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i])); + du_vals[i] += ddelta_u * delta_vals[i]; + ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_; + dA_val += delta_vals[i] * dx * a_conj; + if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } + if constexpr (kIsVariableC) { + dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x); + } + } + // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower + if constexpr (kIsVariableB || kIsVariableC) { + float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2]; + if constexpr (kIsVariableB) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dB_vals_f[i * 2] = dB_vals[i].real_; + dB_vals_f[i * 2 + 1] = dB_vals[i].imag_; + } + Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f); + } + if constexpr (kIsVariableC) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + dC_vals_f[i * 2] = dC_vals[i].real_; + dC_vals_f[i * 2 + 1] = dC_vals[i].imag_; + } + auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; + Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f); + } + const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x; + float *dB_cur = reinterpret_cast(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + float *dC_cur = reinterpret_cast(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; + #pragma unroll + for (int i = 0; i < kNItems * 2; ++i) { + if (i * kNThreads < seqlen_remaining) { + if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); } + if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); } + } + } + } + if constexpr (!kIsVariableB || !kIsVariableC) { + float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_); + dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); + dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y); + dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w); + if (threadIdx.x == 0) { + smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx]; + } + } else { + dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val); + } + if (threadIdx.x == 0) { + smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; + } + } + } + + if constexpr (kDeltaSoftplus) { + __syncthreads(); + input_t delta_vals_load[kNItems]; + load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); + delta -= kChunkSize; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float delta_val = float(delta_vals_load[i]) + delta_bias; + float delta_val_neg_exp = expf(-delta_val); + ddelta_vals[i] = delta_val <= 20.f + ? ddelta_vals[i] / (1.f + delta_val_neg_exp) + : ddelta_vals[i]; + } + } + for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; } + + input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride + + dim_id * params.du_d_stride + chunk * kChunkSize; + input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride + + dim_id * params.ddelta_d_stride + chunk * kChunkSize; + __syncthreads(); + store_output(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize); + __syncthreads(); + store_output(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize); + + Bvar -= kChunkSize * (!kIsComplex ? 1 : 2); + Cvar -= kChunkSize * (!kIsComplex ? 1 : 2); + } + if (params.dD_ptr != nullptr) { + dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val); + if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); } + } + if (params.ddelta_bias_ptr != nullptr) { + __syncthreads(); + ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val); + if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); } + } + for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]); + weight_t dBC_val; + if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; } + if constexpr (!kIsVariableB) { + gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]), + !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val); + } + if constexpr (!kIsVariableC) { + gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]), + !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val); + } + } +} + +template +void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + using Ktraits = Selective_Scan_bwd_kernel_traits; + // using Ktraits = Selective_Scan_bwd_kernel_traits; + // TODO: check this + constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim); + auto kernel = &selective_scan_bwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); + }); +} + +template +void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream); + } +} \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_common.h b/src/transformers/kernels/mamba/selective_scan/selective_scan_common.h new file mode 100644 index 00000000000000..9140dcdf3b68ad --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/selective_scan_common.h @@ -0,0 +1,221 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For scalar_value_type + +#define MAX_DSTATE 256 + +using complex_t = c10::complex; + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +inline __device__ float3 operator+(const float3 &a, const float3 &b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +inline __device__ float4 operator+(const float4 & a, const float4 & b){ + return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { + #pragma unroll + for (int i = 0; i < N; ++i) { dst[i] = src[i]; } + } +}; + +template +struct Converter{ + static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } + } +}; + +#if __CUDA_ARCH__ >= 800 +template +struct Converter{ + static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp +// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696 +__device__ __forceinline__ complex_t cexp2f(complex_t z) { + float t = exp2f(z.real_); + float c, s; + sincosf(z.imag_, &s, &c); + return complex_t(c * t, s * t); +} + +__device__ __forceinline__ complex_t cexpf(complex_t z) { + float t = expf(z.real_); + float c, s; + sincosf(z.imag_, &s, &c); + return complex_t(c * t, s * t); +} + +template struct SSMScanOp; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { + return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); + } +}; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const { + complex_t a0 = complex_t(ab0.x, ab0.y); + complex_t b0 = complex_t(ab0.z, ab0.w); + complex_t a1 = complex_t(ab1.x, ab1.y); + complex_t b1 = complex_t(ab1.z, ab1.w); + complex_t out_a = a1 * a0; + complex_t out_b = a1 * b0 + b1; + return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_); + } +}; + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +template struct SSMScanPrefixCallbackOp { + using scan_t = std::conditional_t, float2, float4>; + scan_t running_prefix; + // Constructor + __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __device__ scan_t operator()(scan_t block_aggregate) { + scan_t old_prefix = running_prefix; + running_prefix = SSMScanOp()(running_prefix, block_aggregate); + return old_prefix; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void load_input(typename Ktraits::input_t *u, + typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadT::TempStorage &smem_load, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_vec = reinterpret_cast(smem_load); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadVecT(smem_load_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + ); + } else { + Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); + } +} + +template +inline __device__ void load_weight(typename Ktraits::input_t *Bvar, + typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, + int seqlen) { + constexpr int kNItems = Ktraits::kNItems; + if constexpr (!Ktraits::kIsComplex) { + typename Ktraits::input_t B_vals_load[kNItems]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load) + ); + } else { + Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + // #pragma unroll + // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } + Converter::to_float(B_vals_load, B_vals); + } else { + typename Ktraits::input_t B_vals_load[kNItems * 2]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load) + ); + } else { + Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); } + } +} + +template +inline __device__ void store_output(typename Ktraits::input_t *out, + const float (&out_vals)[Ktraits::kNItems], + typename Ktraits::BlockStoreT::TempStorage &smem_store, + int seqlen) { + typename Ktraits::input_t write_vals[Ktraits::kNItems]; + #pragma unroll + for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_store_vec = reinterpret_cast(smem_store); + using vec_t = typename Ktraits::vec_t; + Ktraits::BlockStoreVecT(smem_store_vec).Store( + reinterpret_cast(out), + reinterpret_cast(write_vals) + ); + } else { + Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); + } +} diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_bf16.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_bf16.cu new file mode 100644 index 00000000000000..2b8615b1d522c1 --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_bf16.cu @@ -0,0 +1,10 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp16.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp16.cu new file mode 100644 index 00000000000000..015e2a0eff633d --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp16.cu @@ -0,0 +1,10 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp32.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp32.cu new file mode 100644 index 00000000000000..c142fe0208ea78 --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp32.cu @@ -0,0 +1,10 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +// Split into multiple files to compile in paralell + +#include "selective_scan_fwd_kernel.cuh" + +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_kernel.cuh b/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_kernel.cuh new file mode 100644 index 00000000000000..440a209108bfe1 --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_kernel.cuh @@ -0,0 +1,345 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include +#include + +#include "selective_scan.h" +#include "selective_scan_common.h" +#include "static_switch.h" + +template +struct Selective_Scan_fwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsComplex = std::is_same_v; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kHasZ = kHasZ_; + + static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + + using vec_t = typename BytesToType::Type; + using scan_t = std::conditional_t; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr bool kIsComplex = Ktraits::kIsComplex; + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); + // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; + + float D_val[kNRows] = {0}; + if (params.D_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; + } + } + float delta_bias[kNRows] = {0}; + if (params.delta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; + } + } + + // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; + // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; + // } + + constexpr int kChunkSize = kNThreads * kNItems; + for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + } + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; + if (params.delta_softplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + delta_u_vals[r][i] = delta_vals[r][i] * u_val; + out_vals[r][i] = D_val[r] * u_val; + } + } + + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + if constexpr (!kIsComplex) { + A_val[r] *= kLog2e; + } else { + A_val[r].real_ *= kLog2e; + } + } + // This variable holds B * C if both B and C are constant across seqlen. If only B varies + // across seqlen, this holds C. If only C varies across seqlen, this holds B. + // If both B and C vary, this is unused. + weight_t BC_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (kIsVariableB) { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + } + if constexpr (kIsVariableC) { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } + } + if constexpr (!kIsVariableB && !kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if (r > 0) { __syncthreads(); } // Scan could be using the same smem + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + if constexpr (!kIsComplex) { + thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), + !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } else { + // Pytorch's implementation of complex exp (which calls thrust) is very slow + complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]); + weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]; + thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f); + } + } + } + } + // Initialize running total + scan_t running_prefix; + if constexpr (!kIsComplex) { + // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read + running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + } else { + running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); + } + SSMScanPrefixCallbackOp prefix_op(running_prefix); + Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. + if (threadIdx.x == 0) { + smem_running_prefix[state_idx] = prefix_op.running_prefix; + x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const weight_t C_val = !kIsVariableC + ? BC_val[r] + : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]); + if constexpr (!kIsComplex) { + out_vals[r][i] += thread_data[i].y * C_val; + } else { + out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2; + } + } + } + } + + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + input_t z_vals[kNItems]; + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + out_vals[r][i] *= z_val / (1 + expf(-z_val)); + } + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + } + + Bvar += kChunkSize * (!kIsComplex ? 1 : 2); + Cvar += kChunkSize * (!kIsComplex ? 1 : 2); + } +} + +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block + // processing 1 row. + constexpr int kNRows = 1; + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { + BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { + BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + // printf("smem_size = %d\n", kSmemSize); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); +} + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } +} diff --git a/src/transformers/kernels/mamba/selective_scan/static_switch.h b/src/transformers/kernels/mamba/selective_scan/static_switch.h new file mode 100644 index 00000000000000..7920ac045d0a2a --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/static_switch.h @@ -0,0 +1,25 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/src/transformers/kernels/mamba/selective_scan/uninitialized_copy.cuh b/src/transformers/kernels/mamba/selective_scan/uninitialized_copy.cuh new file mode 100644 index 00000000000000..630622dddcc904 --- /dev/null +++ b/src/transformers/kernels/mamba/selective_scan/uninitialized_copy.cuh @@ -0,0 +1,69 @@ +/****************************************************************************** + * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include + +#include + + +namespace detail +{ + +#if defined(_NVHPC_CUDA) +template +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + // NVBug 3384810 + new (ptr) T(::cuda::std::forward(val)); +} +#else +template ::value, + int + >::type = 0> +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + *ptr = ::cuda::std::forward(val); +} + +template ::value, + int + >::type = 0> +__host__ __device__ void uninitialized_copy(T *ptr, U &&val) +{ + new (ptr) T(::cuda::std::forward(val)); +} +#endif + +} // namespace detail From 9f2a98296c888df38c04be45a80e8582f3953d61 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 5 Feb 2024 16:20:33 +0900 Subject: [PATCH 007/116] small refactoring little step --- .../models/mamba/modeling_mamba.py | 72 ++++++++++++++++--- 1 file changed, 62 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 6511e30fce3a26..28fb27c73eac3e 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -301,20 +301,20 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] projected_states = self.in_proj(hidden_states).transpose(1,2) - x, z = projected_states.chunk(2, dim=1) + hidden_states, z = projected_states.chunk(2, dim=1) if inference_params is not None and inference_params.seq_offset > 0: - x = causal_conv1d_update( - x, + hidden_states = causal_conv1d_update( + hidden_states, conv_state, self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), self.conv1d.bias, self.activation, ) else: - conv_state = F.pad(x, (self.d_conv - seqlen, 0)) - x = causal_conv1d_fn( - x=x, + conv_state = F.pad(hidden_states, (self.d_conv - seqlen, 0)) + hidden_states = causal_conv1d_fn( + hidden_states=hidden_states, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), bias=self.conv1d.bias, activation=self.activation, @@ -323,7 +323,7 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): # We're careful here about the layout, to avoid extra transposes. # We want dt to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. - x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) + x_dbl = self.x_proj(rearrange(hidden_states, "b d l -> (b l) d")) # (bl d) dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) dt = self.dt_proj.weight @ dt.t() dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) @@ -331,17 +331,69 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() if inference_params is not None and inference_params.seq_offset > 0: y, _ = selective_scan_update( - ssm_state, x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True + ssm_state, hidden_states, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True ) else: y, last_state = selective_scan_fn( - x, dt, self.negA, B, C, self.D.float(),z=z,delta_bias=self.dt_proj.bias.float(),delta_softplus=True,return_last_state=True, + hidden_states, dt, self.negA, B, C, self.D.float(),z=z,delta_bias=self.dt_proj.bias.float(),delta_softplus=True,return_last_state=True, ) y = rearrange(y, "b d l -> b l d") attn_outputs = self.out_proj(y) return attn_outputs, conv_state, last_state + + +class MambaCache: + + def __init__(self): + pass +class MambaSlowMixer(MambaMixer): + + def forward(self, hidden_states, infer_params=MambaCache()): + batch_size, seqlen, _ = hidden_states.shape + + # 1. Gated MLP's linear projection + projected_states = self.in_proj(hidden_states).transpose(1,2) + hidden_states, gate = projected_states.chunk(2, dim=1) + + # 2. Convolution sequence transformation + if infer_params is not None: + conv_state = infer_params.update_conv_states(hidden_states, ssm_state) + + hidden_states = torch.sum(conv_state * torch.rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + if self.conv1d.bias is not None: + hidden_states = hidden_states + self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype=hidden_states.dtype) + + + # 3. State Space Model sequence transformation + + x_dbl = self.x_proj(torch.rearrange(hidden_states, "b d l -> (b l) d")) # (bl d) + dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = self.dt_proj.weight @ dt.t() + dt = torch.rearrange(dt, "d (b l) -> b d l", l=seqlen) + A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + B = torch.rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + C = torch.rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + + + dt = nn.functional.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) + dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) + dB = torch.einsum("bd,bn->bdn", dt, B) + + if infer_params is not None: + ssm_state = infer_params.update_state_space(hidden_states, dA, dB) + + ssm_state.copy_(ssm_state * dA + torch.rearrange(hidden_states, "b d -> b d 1") * dB) + y = torch.einsum("bdn,bn->bd", ssm_state.to(hidden_states.dtype), C) + y = y + self.D.to(hidden_states.dtype) * hidden_states + y = y * self.act(gate) # (B D) + + # 4. Final linear projection + + attn_outputs = self.out_proj(y) + return attn_outputs, conv_state, y class MambaBlock(nn.Module): @@ -660,7 +712,7 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(x for x in [hidden_states, state, all_hidden_states, all_self_attentions] if x is not None) + return tuple(hidden_states for hidden_states in [hidden_states, state, all_hidden_states, all_self_attentions] if hidden_states is not None) return MambaOutput( last_hidden_state=hidden_states, From 04c991ab72a99ffaa2ad83d0c53507418db79b5e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 5 Feb 2024 17:04:17 +0900 Subject: [PATCH 008/116] add comments --- .../models/mamba/modeling_mamba.py | 39 ++++++++++++------- 1 file changed, 24 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 28fb27c73eac3e..190ee5523180a8 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -345,13 +345,21 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): class MambaCache: def __init__(self): - pass + self.conv_state = None + self.ssm_state = None + def update_conv_state(self, hidden_states): + self.conv_state.copy_(torch.roll(self.conv_state, shifts=-1, dims=-1)) # Update state (B D W) + self.conv_state[:, :, -1] = hidden_states + return self.conv_state + + def update_ssm_state(self, ssm_state): + self.ssm_state.copy_(ssm_state) class MambaSlowMixer(MambaMixer): def forward(self, hidden_states, infer_params=MambaCache()): - batch_size, seqlen, _ = hidden_states.shape + batch_size, seq_len, _ = hidden_states.shape # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1,2) @@ -359,39 +367,40 @@ def forward(self, hidden_states, infer_params=MambaCache()): # 2. Convolution sequence transformation if infer_params is not None: - conv_state = infer_params.update_conv_states(hidden_states, ssm_state) + conv_state = infer_params.update_conv_states(hidden_states) hidden_states = torch.sum(conv_state * torch.rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) if self.conv1d.bias is not None: hidden_states = hidden_states + self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype=hidden_states.dtype) - # 3. State Space Model sequence transformation - + # 3.a. input varying initialization of time_step, B and C x_dbl = self.x_proj(torch.rearrange(hidden_states, "b d l -> (b l) d")) # (bl d) dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) dt = self.dt_proj.weight @ dt.t() - dt = torch.rearrange(dt, "d (b l) -> b d l", l=seqlen) + dt = torch.rearrange(dt, "d (b l) -> b d l", l=seq_len) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - B = torch.rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - C = torch.rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + B = torch.rearrange(B, "(b l) dstate -> b dstate l", l=seq_len).contiguous() + C = torch.rearrange(C, "(b l) dstate -> b dstate l", l=seq_len).contiguous() - + # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) dt = nn.functional.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) dB = torch.einsum("bd,bn->bdn", dt, B) - if infer_params is not None: - ssm_state = infer_params.update_state_space(hidden_states, dA, dB) - - ssm_state.copy_(ssm_state * dA + torch.rearrange(hidden_states, "b d -> b d 1") * dB) - y = torch.einsum("bdn,bn->bd", ssm_state.to(hidden_states.dtype), C) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + ys = [] + for i in range(seq_len): + self.ssm_state.copy_(self.ssm_state * dA + torch.rearrange(hidden_states, "b d -> b d 1") * dB) + y = torch.einsum(self.ssm_state, C[:, i, :], 'b d_in n, b n -> b d_in') + ys.append(y) + y = torch.stack(ys, dim=1) # shape (b, l, d_in) + y = y + self.D.to(hidden_states.dtype) * hidden_states y = y * self.act(gate) # (B D) # 4. Final linear projection - attn_outputs = self.out_proj(y) return attn_outputs, conv_state, y From aa7e8d2b506191adf86310f04018fd840008ac11 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 5 Feb 2024 17:18:06 +0900 Subject: [PATCH 009/116] styling --- src/transformers/__init__.py | 14 +++---- src/transformers/models/__init__.py | 2 +- .../models/mamba/configuration_mamba.py | 1 + .../mamba/convert_mamba_checkpoint_to_hf.py | 2 +- .../models/mamba/modeling_mamba.py | 42 +++++++++---------- 5 files changed, 30 insertions(+), 31 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index dc9a4766e1d91d..fc5f655da62b79 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -5306,6 +5306,7 @@ LxmertTokenizer, ) from .models.m2m_100 import M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP, M2M100Config + from .models.mamba import MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP, MambaConfig from .models.marian import MarianConfig from .models.markuplm import ( MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -5484,7 +5485,6 @@ RoFormerTokenizer, ) from .models.rwkv import RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP, RwkvConfig - from .models.mamba import MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP, MambaConfig from .models.sam import ( SAM_PRETRAINED_CONFIG_ARCHIVE_MAP, SamConfig, @@ -7076,6 +7076,12 @@ M2M100Model, M2M100PreTrainedModel, ) + from .models.mamba import ( + MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST, + MambaForCausalLM, + MambaModel, + MambaPreTrainedModel, + ) from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel from .models.markuplm import ( MARKUPLM_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -7550,12 +7556,6 @@ RwkvModel, RwkvPreTrainedModel, ) - from .models.mamba import ( - MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST, - MambaForCausalLM, - MambaModel, - MambaPreTrainedModel, - ) from .models.sam import ( SAM_PRETRAINED_MODEL_ARCHIVE_LIST, SamModel, diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 0249a8ec993b70..bc3d271f3961c5 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -127,6 +127,7 @@ luke, lxmert, m2m_100, + mamba, marian, markuplm, mask2former, @@ -189,7 +190,6 @@ roc_bert, roformer, rwkv, - mamba, sam, seamless_m4t, seamless_m4t_v2, diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 8c83cbe02ed503..8d8efd66708f0c 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -15,6 +15,7 @@ # limitations under the License. """ MAMBA configuration""" import math + from ...configuration_utils import PretrainedConfig from ...utils import logging diff --git a/src/transformers/models/mamba/convert_mamba_checkpoint_to_hf.py b/src/transformers/models/mamba/convert_mamba_checkpoint_to_hf.py index 914dab0229f3b6..7f45e24ab77f7f 100644 --- a/src/transformers/models/mamba/convert_mamba_checkpoint_to_hf.py +++ b/src/transformers/models/mamba/convert_mamba_checkpoint_to_hf.py @@ -24,7 +24,7 @@ import torch from huggingface_hub import hf_hub_download -from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerFast, MambaConfig +from transformers import AutoModelForCausalLM, AutoTokenizer, MambaConfig, PreTrainedTokenizerFast from transformers.modeling_utils import WEIGHTS_INDEX_NAME, shard_checkpoint diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 190ee5523180a8..51fa441d008174 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -31,9 +31,6 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, - is_bitsandbytes_available, - is_ninja_available, - is_torch_cuda_available, logging, ) from .configuration_mamba import MambaConfig @@ -249,7 +246,7 @@ def mamba_linear_attention(time_decay, time_first, key, value, state=None, retur class MambaMixer(nn.Module): - + def __init__(self, config, layer_idx): super().__init__() self.d_model = config.d_model @@ -274,12 +271,12 @@ def __init__(self, config, layer_idx): self.activation = config.activation self.act = ACT2FN[config.activation] - + # selective projection used to make dt, B and C input dependant self.x_proj = nn.Linear(self.d_inner, self.time_step_rank + self.d_state * 2, bias=False) self.time_step_proj = nn.Linear(self.time_step_rank, self.d_inner, bias=True) - # S4D real initialization. These are not discretized! - # THe core is to load them, compute the discrete states, then write the updates state. + # S4D real initialization. These are not discretized! + # THe core is to load them, compute the discrete states, then write the updates state. # Keeps the memory bounded what_is_this = torch.arange(1, self.d_state + 1, dtype=torch.float32) A = torch.repeat(what_is_this,d=self.d_inner).contiguous() @@ -302,11 +299,11 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): projected_states = self.in_proj(hidden_states).transpose(1,2) hidden_states, z = projected_states.chunk(2, dim=1) - + if inference_params is not None and inference_params.seq_offset > 0: hidden_states = causal_conv1d_update( hidden_states, - conv_state, + conv_state, self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), self.conv1d.bias, self.activation, @@ -340,10 +337,10 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): y = rearrange(y, "b d l -> b l d") attn_outputs = self.out_proj(y) return attn_outputs, conv_state, last_state - + class MambaCache: - + def __init__(self): self.conv_state = None self.ssm_state = None @@ -357,23 +354,24 @@ def update_ssm_state(self, ssm_state): self.ssm_state.copy_(ssm_state) class MambaSlowMixer(MambaMixer): - + def forward(self, hidden_states, infer_params=MambaCache()): batch_size, seq_len, _ = hidden_states.shape - + # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1,2) hidden_states, gate = projected_states.chunk(2, dim=1) - + # 2. Convolution sequence transformation if infer_params is not None: conv_state = infer_params.update_conv_states(hidden_states) + # TODO replace with simple conv call hidden_states = torch.sum(conv_state * torch.rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) if self.conv1d.bias is not None: hidden_states = hidden_states + self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype=hidden_states.dtype) - + # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C x_dbl = self.x_proj(torch.rearrange(hidden_states, "b d l -> (b l) d")) # (bl d) @@ -383,28 +381,28 @@ def forward(self, hidden_states, infer_params=MambaCache()): A = -torch.exp(self.A_log.float()) # (d_inner, d_state) B = torch.rearrange(B, "(b l) dstate -> b dstate l", l=seq_len).contiguous() C = torch.rearrange(C, "(b l) dstate -> b dstate l", l=seq_len).contiguous() - + # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) dt = nn.functional.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) dB = torch.einsum("bd,bn->bdn", dt, B) - + # 3.c perform the recurrence y ← SSM(A, B, C)(x) - ys = [] + ys = [] for i in range(seq_len): self.ssm_state.copy_(self.ssm_state * dA + torch.rearrange(hidden_states, "b d -> b d 1") * dB) y = torch.einsum(self.ssm_state, C[:, i, :], 'b d_in n, b n -> b d_in') ys.append(y) y = torch.stack(ys, dim=1) # shape (b, l, d_in) - + y = y + self.D.to(hidden_states.dtype) * hidden_states y = y * self.act(gate) # (B D) - + # 4. Final linear projection attn_outputs = self.out_proj(y) return attn_outputs, conv_state, y - + class MambaBlock(nn.Module): def __init__(self, config, layer_id): super().__init__() @@ -419,7 +417,7 @@ def forward(self, hidden_states, residual=None, inference_params=None): hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) if self.residual_in_fp32: residual = residual.to(torch.float32) - + hidden_states = self.mixer(hidden_states, inference_params=inference_params) outputs = (hidden_states, residual) return outputs From 26748c4a591566c0d12e182db03c8df288a98223 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 5 Feb 2024 18:10:47 +0900 Subject: [PATCH 010/116] nit --- src/transformers/models/mamba/modeling_mamba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 51fa441d008174..3d39b66970b925 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -234,7 +234,7 @@ def mamba_linear_attention_cpu(time_decay, time_first, key, value, state=None, r return output, state -def mamba_linear_attention(time_decay, time_first, key, value, state=None, return_state=False): +def mamba_mixer_forward(time_decay, time_first, key, value, state=None, return_state=False): no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value]) # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version # in this case). @@ -242,7 +242,7 @@ def mamba_linear_attention(time_decay, time_first, key, value, state=None, retur if mamba_cuda_kernel is None or no_cuda or one_token: return mamba_linear_attention_cpu(time_decay, time_first, key, value, state=state, return_state=return_state) else: - return MambaLinearAttention.apply(time_decay, time_first, key, value, state, return_state) + return MambaMixer.apply(time_decay, time_first, key, value, state, return_state) class MambaMixer(nn.Module): From 75e376a36e7f4cb474e704f1e43d060ea7ae28c2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 01:25:09 +0100 Subject: [PATCH 011/116] nits --- .../models/mamba/configuration_mamba.py | 10 +- .../models/mamba/modeling_mamba.py | 107 ++++++++++-------- 2 files changed, 71 insertions(+), 46 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 8d8efd66708f0c..de55b6a0561557 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -102,6 +102,10 @@ def __init__( expand=2, dt_rank="auto", tie_word_embeddings=True, + use_bias=False, + use_conv_bias=False, + hidden_act="silu", + initializer_range=0.1, **kwargs, ): self.vocab_size = vocab_size @@ -110,13 +114,17 @@ def __init__( self.num_hidden_layers = num_hidden_layers self.layer_norm_epsilon = layer_norm_epsilon self.d_inner = hidden_size * 2 - self.d_conv = 4 + self.conv_kernel = 4 self.state_size = state_size self.expand = expand self.time_step_rank = math.ceil(self.hidden_size / 16) if dt_rank == "auto" else dt_rank self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id + self.use_bias = use_bias + self.use_conv_bias = use_conv_bias + self.hidden_act=hidden_act + self.initializer_range = initializer_range super().__init__( tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 3d39b66970b925..44d924ce1d3ba0 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -249,28 +249,28 @@ class MambaMixer(nn.Module): def __init__(self, config, layer_idx): super().__init__() - self.d_model = config.d_model - self.d_state = config.d_state - self.d_conv = config.d_conv + self.d_model = config.hidden_size + self.d_state = config.state_size + self.d_conv = config.conv_kernel self.expand = config.expand self.d_inner = int(self.expand * self.d_model) self.time_step_rank = math.ceil(self.d_model / 16) if config.time_step_rank == "auto" else config.time_step_rank - self.use_fast_path = config.use_fast_path + # self.use_fast_path = config.use_fast_path self.layer_idx = layer_idx - self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=config.bias) + self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=config.use_bias) self.conv1d = nn.Conv1d( in_channels=self.d_inner, out_channels=self.d_inner, - bias=config.conv_bias, - kernel_size=config.d_conv, + bias=config.use_conv_bias, + kernel_size=config.conv_kernel, groups=self.d_inner, - padding=config.d_conv - 1, + padding=config.conv_kernel - 1, ) - self.activation = config.activation - self.act = ACT2FN[config.activation] + self.activation = config.hidden_act + self.act = ACT2FN[config.hidden_act] # selective projection used to make dt, B and C input dependant self.x_proj = nn.Linear(self.d_inner, self.time_step_rank + self.d_state * 2, bias=False) @@ -279,7 +279,7 @@ def __init__(self, config, layer_idx): # THe core is to load them, compute the discrete states, then write the updates state. # Keeps the memory bounded what_is_this = torch.arange(1, self.d_state + 1, dtype=torch.float32) - A = torch.repeat(what_is_this,d=self.d_inner).contiguous() + A = what_is_this.repeat(self.d_inner).contiguous() A_log = torch.log(A) # Keep A_log in fp32 self.A_log = nn.Parameter(A_log) self.A_log._no_weight_decay = True @@ -287,7 +287,7 @@ def __init__(self, config, layer_idx): # D "skip" parameter self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32 self.D._no_weight_decay = True - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=config.bias) + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=config.use_conv_bias) def forward(self, hidden_states: torch.Tensor, inference_params=None): """ @@ -296,7 +296,7 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): """ _, seqlen, _ = hidden_states.shape conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - + return None projected_states = self.in_proj(hidden_states).transpose(1,2) hidden_states, z = projected_states.chunk(2, dim=1) @@ -367,9 +367,10 @@ def forward(self, hidden_states, infer_params=MambaCache()): conv_state = infer_params.update_conv_states(hidden_states) # TODO replace with simple conv call - hidden_states = torch.sum(conv_state * torch.rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - if self.conv1d.bias is not None: - hidden_states = hidden_states + self.conv1d.bias + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) + # hidden_states = torch.sum(conv_state * torch.rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + # if self.conv1d.bias is not None: + # hidden_states = hidden_states + self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype=hidden_states.dtype) # 3. State Space Model sequence transformation @@ -377,13 +378,17 @@ def forward(self, hidden_states, infer_params=MambaCache()): x_dbl = self.x_proj(torch.rearrange(hidden_states, "b d l -> (b l) d")) # (bl d) dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) dt = self.dt_proj.weight @ dt.t() - dt = torch.rearrange(dt, "d (b l) -> b d l", l=seq_len) + + dt = dt.transpose(0,1) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - B = torch.rearrange(B, "(b l) dstate -> b dstate l", l=seq_len).contiguous() - C = torch.rearrange(C, "(b l) dstate -> b dstate l", l=seq_len).contiguous() + + B = B.permute(0,2,1).contiguous() + C = C.permute(0,2,1).contiguous() # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) dt = nn.functional.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) + + # TODO replace einsums dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) dB = torch.einsum("bd,bn->bdn", dt, B) @@ -403,20 +408,49 @@ def forward(self, hidden_states, infer_params=MambaCache()): return attn_outputs, conv_state, y + _xz = self.in_proj(hidden_states) + _x, _z = _xz.chunk(2, dim=-1) # (B D) + conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1) + conv_out = causal_conv1d_fn( + x=conv_state_new, + weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), + bias=self.conv1d.bias, + activation=self.activation + ) + conv_state = conv_state_new[:, :, 1:] + bsz, seqlen, dim = hidden_states.shape + output_tensor = torch.zeros( + (bsz, seqlen, dim), + device=hidden_states.device, + dtype=hidden_states.dtype + ) + for i in range(0, bsz): + x = conv_out[i:i+1,:,-1] + z = _z[i:i+1, -1, :] + x_db = self.x_proj(x) + dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = F.linear(dt, self.dt_proj.weight) + y = selective_state_update( + ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True + ) + out = self.out_proj(y) + output_tensor[i] = out + + class MambaBlock(nn.Module): - def __init__(self, config, layer_id): + def __init__(self, config, layer_idx): super().__init__() self.config = config - self.layer_id = layer_id - self.residual_in_fp32 = config.residual_in_fp32 + self.layer_idx = layer_idx + # self.residual_in_fp32 = config.residual_in_fp32 self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.mixer = MambaMixer(config.hidden_size, eps=config.layer_norm_epsilon) + self.mixer = MambaMixer(config, layer_idx = layer_idx) def forward(self, hidden_states, residual=None, inference_params=None): residual = (hidden_states + residual) if residual is not None else hidden_states hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) + # if self.residual_in_fp32: + # residual = residual.to(torch.float32) hidden_states = self.mixer(hidden_states, inference_params=inference_params) outputs = (hidden_states, residual) @@ -482,25 +516,8 @@ def _init_weights(self, module): elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=self.config.initializer_range) - if self.config.rescale_prenorm_residual: - # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: - # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale - # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. - # > -- GPT-2 :: https://openai.com/blog/better-language-models/ - # - # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - for name, p in module.named_parameters(): - if name in ["out_proj.weight", "fc2.weight"]: - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) - # We need to reinit p since this code could be called multiple times - # Having just p *= scale would repeatedly scale it down - nn.init.kaiming_uniform_(p, a=math.sqrt(5)) - with torch.no_grad(): - p /= math.sqrt(self.config.n_residuals_per_layer * self.config.n_layer) - - def _setup_cache(self, batch_size, max_seqlen, dtype): - raise NotImplementedError + + @dataclass class MambaOutput(ModelOutput): @@ -633,7 +650,7 @@ def __init__(self, config): super().__init__(config) self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList([MambaBlock(config, layer_id=idx) for idx in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) self.norm_f = nn.LayerNorm(config.hidden_size) # ir use ALL_LAYER_NORM[config.hidden_states] self.layers_are_rescaled = False From 1c104b51b2bef63c80a526bb7ae5f2fc80cd2345 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 12:39:39 +0900 Subject: [PATCH 012/116] Style --- src/transformers/models/mamba/modeling_mamba.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 3d39b66970b925..116c8683d6826e 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -660,6 +660,7 @@ def forward( attention_mask: Optional[torch.LongTensor] = None, # noqa inputs_embeds: Optional[torch.FloatTensor] = None, state: Optional[List[torch.FloatTensor]] = None, + inference_params: Optional[List[torch.FloatTensor]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -681,7 +682,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embeddings(input_ids) # TODO better to call _set_cache - if use_cache and cache is None: + if use_cache and inference_params is None: shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers) dtype = inputs_embeds.dtype if i <= 1 else torch.float32 cache = [torch.zeros(*shape, dtype=dtype, device=inputs_embeds.device)for i in range(5)] From a804466579f4d7c95ada664f0d2c6b7f5e21b50b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 13:05:53 +0900 Subject: [PATCH 013/116] Small changes --- .../models/mamba/modeling_mamba.py | 80 +++++-------------- 1 file changed, 20 insertions(+), 60 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index c7e3b2e86030b4..05ac4459735eba 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -355,7 +355,7 @@ def update_ssm_state(self, ssm_state): class MambaSlowMixer(MambaMixer): - def forward(self, hidden_states, infer_params=MambaCache()): + def forward(self, hidden_states, inference_params=MambaCache()): batch_size, seq_len, _ = hidden_states.shape # 1. Gated MLP's linear projection @@ -363,8 +363,8 @@ def forward(self, hidden_states, infer_params=MambaCache()): hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation - if infer_params is not None: - conv_state = infer_params.update_conv_states(hidden_states) + if inference_params is not None: + conv_state = inference_params.update_conv_states(hidden_states) # TODO replace with simple conv call hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) @@ -457,7 +457,6 @@ def forward(self, hidden_states, residual=None, inference_params=None): return outputs -# Copied from transformers.models.mamba.modeling_mamba.mambaPreTrainedModel with mamba->Mamba,mamba->mamba class MambaPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -467,48 +466,12 @@ class MambaPreTrainedModel(PreTrainedModel): config_class = MambaConfig base_model_prefix = "mamba" _no_split_modules = ["MambaBlock"] - _keep_in_fp32_modules = ["time_decay", "time_first"] supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, MambaMixer): - layer_id = module.layer_id - num_hidden_layers = module.config.num_hidden_layers - hidden_size = module.config.hidden_size - attention_hidden_size = module.attention_hidden_size - - ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1 - ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0 - - time_weight = torch.tensor( - [i / hidden_size for i in range(hidden_size)], - dtype=module.time_mix_key.dtype, - device=module.time_mix_key.device, - ) - time_weight = time_weight[None, None, :] - - decay_speed = [ - -5 + 8 * (h / (attention_hidden_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1) - for h in range(attention_hidden_size) - ] - decay_speed = torch.tensor(decay_speed, dtype=module.time_decay.dtype, device=module.time_decay.device) - zigzag = ( - torch.tensor( - [(i + 1) % 3 - 1 for i in range(attention_hidden_size)], - dtype=module.time_first.dtype, - device=module.time_first.device, - ) - * 0.5 - ) - - with torch.no_grad(): - module.time_decay.data = decay_speed - module.time_first.data = torch.ones_like(module.time_first * math.log(0.3) + zigzag) - - module.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) - module.time_mix_value.data = torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 - module.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0) + pass if isinstance(module, nn.Linear): if module.bias is not None: if not getattr(module.bias, "_no_reinit", False): @@ -527,7 +490,7 @@ class MambaOutput(ModelOutput): Args: last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. - state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`): + inference_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): @@ -544,7 +507,7 @@ class MambaOutput(ModelOutput): """ last_hidden_state: torch.FloatTensor = None - state: Optional[List[torch.FloatTensor]] = None + inference_params: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -559,7 +522,7 @@ class MambaCausalLMOutput(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - state (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`): + inference_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): @@ -577,7 +540,7 @@ class MambaCausalLMOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None - state: Optional[List[torch.FloatTensor]] = None + inference_params: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None @@ -612,20 +575,11 @@ class MambaCausalLMOutput(ModelOutput): [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - attention_mask (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): - Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - - This is currently not used by `MambaModel`, but will be supported in the future. - - [What are attention masks?](../glossary#attention-mask) inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - state (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*): + inference_params (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). use_cache (`bool`, *optional*): @@ -674,8 +628,6 @@ def set_input_embeddings(self, new_embeddings): def forward( self, input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.LongTensor] = None, # noqa - inputs_embeds: Optional[torch.FloatTensor] = None, state: Optional[List[torch.FloatTensor]] = None, inference_params: Optional[List[torch.FloatTensor]] = None, use_cache: Optional[bool] = None, @@ -772,18 +724,24 @@ def get_output_embeddings(self): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings - def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs): + def get_input_embeddings(self): + return self.backbone.get_input_embeddings() + + def set_input_embeddings(self, new_embeddings): + return self.backbone.set_input_embeddings(new_embeddings) + + def prepare_inputs_for_generation(self, input_ids, inference_params=None, inputs_embeds=None,attention_mask=None,**kwargs): # only last token for inputs_ids if the state is passed along. if state is not None: input_ids = input_ids[:, -1].unsqueeze(-1) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step - if inputs_embeds is not None and state is None: + if inputs_embeds is not None and inference_params is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} - model_inputs["state"] = state + model_inputs["inference_params"] = inference_params return model_inputs @add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING) @@ -796,6 +754,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, + inference_params: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -811,6 +770,7 @@ def forward( mamba_outputs = self.mamba( input_ids, + inference_params=inference_params, inputs_embeds=inputs_embeds, output_attentions=output_attentions, output_hidden_states=output_hidden_states, From 6b87ad2c106325b178512035e4ce4680fe2e549c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 14 Feb 2024 17:30:44 +0900 Subject: [PATCH 014/116] Push dummy mambda simple slow --- .../models/mamba/modeling_mamba.py | 161 ++++++++---------- 1 file changed, 75 insertions(+), 86 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 05ac4459735eba..4038d6f3529e6c 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -258,7 +258,7 @@ def __init__(self, config, layer_idx): # self.use_fast_path = config.use_fast_path self.layer_idx = layer_idx - self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=config.use_bias) + self.conv1d = nn.Conv1d( in_channels=self.d_inner, @@ -272,14 +272,16 @@ def __init__(self, config, layer_idx): self.activation = config.hidden_act self.act = ACT2FN[config.hidden_act] + # projection of the input hidden states + self.input_projection = nn.Linear(self.d_model, self.d_inner * 2, bias=config.use_bias) # selective projection used to make dt, B and C input dependant - self.x_proj = nn.Linear(self.d_inner, self.time_step_rank + self.d_state * 2, bias=False) - self.time_step_proj = nn.Linear(self.time_step_rank, self.d_inner, bias=True) + self.discrete_projection = nn.Linear(self.d_inner, self.time_step_rank + self.d_state * 2, bias=False) + # time step projection (discretization) + self.time_step_projection = nn.Linear(self.time_step_rank, self.d_inner, bias=True) # S4D real initialization. These are not discretized! # THe core is to load them, compute the discrete states, then write the updates state. # Keeps the memory bounded - what_is_this = torch.arange(1, self.d_state + 1, dtype=torch.float32) - A = what_is_this.repeat(self.d_inner).contiguous() + A = torch.arange(1, self.d_state + 1, dtype=torch.float32)[None,:].expand(self.d_inner, -1).contiguous() A_log = torch.log(A) # Keep A_log in fp32 self.A_log = nn.Parameter(A_log) self.A_log._no_weight_decay = True @@ -295,8 +297,8 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): Returns: same shape as hidden_states """ _, seqlen, _ = hidden_states.shape - conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - return None + # conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + projected_states = self.in_proj(hidden_states).transpose(1,2) hidden_states, z = projected_states.chunk(2, dim=1) @@ -355,86 +357,74 @@ def update_ssm_state(self, ssm_state): class MambaSlowMixer(MambaMixer): - def forward(self, hidden_states, inference_params=MambaCache()): + def forward(self, hidden_states, inference_params=None): + """ + + Compute ∆ A B C D, the state space parameters. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + + Args: + hidden_states: + inference_params: + + Returns: + + """ batch_size, seq_len, _ = hidden_states.shape # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states).transpose(1,2) + projected_states = self.input_projection(hidden_states).transpose(1,2) hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation if inference_params is not None: conv_state = inference_params.update_conv_states(hidden_states) - # TODO replace with simple conv call + # conv_state.copy_(self.conv1d(hidden_states)[..., :seq_len]) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # hidden_states = torch.sum(conv_state * torch.rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) # if self.conv1d.bias is not None: # hidden_states = hidden_states + self.conv1d.bias - hidden_states = self.act(hidden_states).to(dtype=hidden_states.dtype) # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C - x_dbl = self.x_proj(torch.rearrange(hidden_states, "b d l -> (b l) d")) # (bl d) - dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) - dt = self.dt_proj.weight @ dt.t() - - dt = dt.transpose(0,1) + x_dbl = self.discrete_projection(hidden_states.transpose(1,2)) + time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1) + discrete_time_step = self.time_step_projection(time_step) + + # discrete_time_step = discrete_time_step.transpose(0,1) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) - - B = B.permute(0,2,1).contiguous() - C = C.permute(0,2,1).contiguous() # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) - dt = nn.functional.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype)) - - # TODO replace einsums - dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A)) - dB = torch.einsum("bd,bn->bdn", dt, B) - + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1,2) + # [batch_size, d, l, 1] X [1, d, 1, n] -> [batch_size, d, l, n] + dA = torch.exp(discrete_time_step[:, :, :, None] * A[None, :, None, :]) + # [batch_size, d, l, 1] [b, d, l, 1] -> [batch_size, d, l, 1] X [batch_size, 1, l, n] -> [batch_size, d, l, n] + deltaB_u = (discrete_time_step[:, :, :, None] * hidden_states[:, :, :, None]) * B[:, None, :, :] + + ssm_state = torch.zeros((batch_size, self.d_inner, self.d_state), device=A.device) + # ssm_state = inference_params.ssm_state # 3.c perform the recurrence y ← SSM(A, B, C)(x) + ys = [] for i in range(seq_len): - self.ssm_state.copy_(self.ssm_state * dA + torch.rearrange(hidden_states, "b d -> b d 1") * dB) - y = torch.einsum(self.ssm_state, C[:, i, :], 'b d_in n, b n -> b d_in') - ys.append(y) - y = torch.stack(ys, dim=1) # shape (b, l, d_in) + ssm_state.copy_(ssm_state * dA[:, :, i, :] + deltaB_u[:, :, i, :]) + # [b, d, n] X [b, n] -> [b, d] + y = torch.matmul(ssm_state, C[:,i,:].unsqueeze(-1)) + ys.append(y[:,:,0]) + y = torch.stack(ys, dim=1) # shape (b, l, d) - y = y + self.D.to(hidden_states.dtype) * hidden_states - y = y * self.act(gate) # (B D) + y = y + (hidden_states * self.D.to(hidden_states.dtype)[None,:,None]).transpose(1,2) + y = y * self.act(gate).transpose(1,2) # (B D) # 4. Final linear projection attn_outputs = self.out_proj(y) - return attn_outputs, conv_state, y - + return attn_outputs, None, ssm_state, y + return attn_outputs, conv_state, ssm_state, y - _xz = self.in_proj(hidden_states) - _x, _z = _xz.chunk(2, dim=-1) # (B D) - conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1) - conv_out = causal_conv1d_fn( - x=conv_state_new, - weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), - bias=self.conv1d.bias, - activation=self.activation - ) - conv_state = conv_state_new[:, :, 1:] - bsz, seqlen, dim = hidden_states.shape - output_tensor = torch.zeros( - (bsz, seqlen, dim), - device=hidden_states.device, - dtype=hidden_states.dtype - ) - for i in range(0, bsz): - x = conv_out[i:i+1,:,-1] - z = _z[i:i+1, -1, :] - x_db = self.x_proj(x) - dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) - dt = F.linear(dt, self.dt_proj.weight) - y = selective_state_update( - ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True - ) - out = self.out_proj(y) - output_tensor[i] = out class MambaBlock(nn.Module): @@ -444,17 +434,17 @@ def __init__(self, config, layer_idx): self.layer_idx = layer_idx # self.residual_in_fp32 = config.residual_in_fp32 self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.mixer = MambaMixer(config, layer_idx = layer_idx) + self.mixer = MambaSlowMixer(config, layer_idx = layer_idx) - def forward(self, hidden_states, residual=None, inference_params=None): - residual = (hidden_states + residual) if residual is not None else hidden_states + def forward(self, hidden_states, inference_params=None): + residual = hidden_states hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) # if self.residual_in_fp32: # residual = residual.to(torch.float32) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) - outputs = (hidden_states, residual) - return outputs + hidden_states, con_states, ssm_state, y = self.mixer(hidden_states, inference_params=inference_params) + hidden_states = residual + hidden_states + return hidden_states class MambaPreTrainedModel(PreTrainedModel): @@ -605,7 +595,7 @@ def __init__(self, config): self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) - self.norm_f = nn.LayerNorm(config.hidden_size) # ir use ALL_LAYER_NORM[config.hidden_states] + self.norm_f = nn.LayerNorm(config.hidden_size) self.layers_are_rescaled = False self.gradient_checkpointing = False @@ -628,7 +618,7 @@ def set_input_embeddings(self, new_embeddings): def forward( self, input_ids: Optional[torch.LongTensor] = None, - state: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.LongTensor] = None, inference_params: Optional[List[torch.FloatTensor]] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -642,7 +632,6 @@ def forward( use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is None and inputs_embeds is None: @@ -650,6 +639,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embeddings(input_ids) + # TODO better to call _set_cache if use_cache and inference_params is None: shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers) @@ -665,37 +655,36 @@ def forward( hidden_states = inputs_embeds - all_self_attentions = () if output_attentions else None + + all_last_states = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for idx, layer in enumerate(self.layers): + ssm_state = None if self.gradient_checkpointing and self.training: - hidden_states, cache, partial_states = self._gradient_checkpointing_func( - layer.__call__, hidden_states, state, use_cache, output_attentions - ) + hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states, inference_params) else: - hidden_states, cache, partial_states = layer( - hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions - ) + hidden_states = layer(hidden_states, inference_params=inference_params) + # inference_params.conv_state_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if output_attentions: - all_self_attentions = all_self_attentions + (partial_states,) + all_self_attentions = all_last_states + (ssm_state,) - hidden_states = self.ln_out(hidden_states) + hidden_states = self.norm_f(hidden_states) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(hidden_states for hidden_states in [hidden_states, state, all_hidden_states, all_self_attentions] if hidden_states is not None) + return tuple(hidden_states for hidden_states in [hidden_states, inference_params, all_hidden_states, all_last_states] if hidden_states is not None) return MambaOutput( last_hidden_state=hidden_states, - state=cache, + inference_params=inference_params, hidden_states=all_hidden_states, - attentions=all_self_attentions, + attentions=all_last_states, ) @@ -730,9 +719,9 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): return self.backbone.set_input_embeddings(new_embeddings) - def prepare_inputs_for_generation(self, input_ids, inference_params=None, inputs_embeds=None,attention_mask=None,**kwargs): + def prepare_inputs_for_generation(self, input_ids, inference_params=None, inputs_embeds=None, attention_mask=None, **kwargs): # only last token for inputs_ids if the state is passed along. - if state is not None: + if inference_params is not None: input_ids = input_ids[:, -1].unsqueeze(-1) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step @@ -768,7 +757,7 @@ def forward( """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict - mamba_outputs = self.mamba( + mamba_outputs = self.backbone( input_ids, inference_params=inference_params, inputs_embeds=inputs_embeds, @@ -778,7 +767,7 @@ def forward( ) hidden_states = mamba_outputs[0] - logits = self.head(hidden_states) + logits = self.lm_head(hidden_states) loss = None if labels is not None: @@ -798,7 +787,7 @@ def forward( return MambaCausalLMOutput( loss=loss, logits=logits, - cache=mamba_outputs.cache, + inference_params=mamba_outputs.inference_params, hidden_states=mamba_outputs.hidden_states, attentions=mamba_outputs.attentions, ) From a7ec8d637f108f8a9231f4076c0952a8624966fa Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 15 Feb 2024 08:24:14 +0900 Subject: [PATCH 015/116] nit --- src/transformers/models/mamba/modeling_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 4038d6f3529e6c..3f846a26bf4434 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -643,7 +643,7 @@ def forward( # TODO better to call _set_cache if use_cache and inference_params is None: shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers) - dtype = inputs_embeds.dtype if i <= 1 else torch.float32 + dtype = inputs_embeds.dtype cache = [torch.zeros(*shape, dtype=dtype, device=inputs_embeds.device)for i in range(5)] if self.gradient_checkpointing and self.training: From 50464518f7e6fb1d669b906296f10c7e9e5e0366 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 15 Feb 2024 08:35:54 +0900 Subject: [PATCH 016/116] Use original names --- src/transformers/models/mamba/modeling_mamba.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 3f846a26bf4434..d821ab723a773d 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -273,11 +273,11 @@ def __init__(self, config, layer_idx): self.act = ACT2FN[config.hidden_act] # projection of the input hidden states - self.input_projection = nn.Linear(self.d_model, self.d_inner * 2, bias=config.use_bias) + self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=config.use_bias) # selective projection used to make dt, B and C input dependant - self.discrete_projection = nn.Linear(self.d_inner, self.time_step_rank + self.d_state * 2, bias=False) + self.x_proj = nn.Linear(self.d_inner, self.time_step_rank + self.d_state * 2, bias=False) # time step projection (discretization) - self.time_step_projection = nn.Linear(self.time_step_rank, self.d_inner, bias=True) + self.dt_proj = nn.Linear(self.time_step_rank, self.d_inner, bias=True) # S4D real initialization. These are not discretized! # THe core is to load them, compute the discrete states, then write the updates state. # Keeps the memory bounded @@ -375,7 +375,7 @@ def forward(self, hidden_states, inference_params=None): batch_size, seq_len, _ = hidden_states.shape # 1. Gated MLP's linear projection - projected_states = self.input_projection(hidden_states).transpose(1,2) + projected_states = self.in_proj(hidden_states).transpose(1,2) hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation @@ -391,9 +391,9 @@ def forward(self, hidden_states, inference_params=None): # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C - x_dbl = self.discrete_projection(hidden_states.transpose(1,2)) + x_dbl = self.x_proj(hidden_states.transpose(1,2)) time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1) - discrete_time_step = self.time_step_projection(time_step) + discrete_time_step = self.dt_proj(time_step) # discrete_time_step = discrete_time_step.transpose(0,1) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) From b5831e3dad5954318132f9b68e67035f08078a25 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 15 Feb 2024 09:19:28 +0900 Subject: [PATCH 017/116] Use original names and remove norm --- .../models/mamba/modeling_mamba.py | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index d821ab723a773d..71b0aaba7bbb83 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -425,7 +425,22 @@ def forward(self, hidden_states, inference_params=None): return attn_outputs, None, ssm_state, y return attn_outputs, conv_state, ssm_state, y - +class MambaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + class MambaBlock(nn.Module): def __init__(self, config, layer_idx): @@ -433,7 +448,7 @@ def __init__(self, config, layer_idx): self.config = config self.layer_idx = layer_idx # self.residual_in_fp32 = config.residual_in_fp32 - self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.mixer = MambaSlowMixer(config, layer_idx = layer_idx) def forward(self, hidden_states, inference_params=None): @@ -595,7 +610,6 @@ def __init__(self, config): self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) - self.norm_f = nn.LayerNorm(config.hidden_size) self.layers_are_rescaled = False self.gradient_checkpointing = False @@ -672,8 +686,6 @@ def forward( if output_attentions: all_self_attentions = all_last_states + (ssm_state,) - hidden_states = self.norm_f(hidden_states) - if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) From e9a80ad87913bef9abf2fabdbf36769fbf53a772 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 15 Feb 2024 15:32:16 +0900 Subject: [PATCH 018/116] Updates for inference params --- .../models/mamba/modeling_mamba.py | 95 ++++++++++++------- 1 file changed, 60 insertions(+), 35 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 71b0aaba7bbb83..bdc39e10ee4eee 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union, Dict, Any import torch import torch.utils.checkpoint @@ -342,10 +342,19 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): class MambaCache: - - def __init__(self): - self.conv_state = None - self.ssm_state = None + def __init__(self, config, batch_size, conv_dtype=torch.float32, ssm_dtype=torch.float32, device=None): + self.seqlen_offset = 0 + d_model = config.hidden_size + d_state = config.state_size + expand = config.expand + d_conv = config.conv_kernel + + self.conv_states = { i: torch.zeros( + batch_size, d_model * expand, d_conv, device=device, dtype=conv_dtype + ) for i in range(config.num_hidden_layers)} + self.ssm_states = { i: torch.zeros( + batch_size, d_model * expand, d_state, device=device, dtype=ssm_dtype + )for i in range(config.num_hidden_layers)} def update_conv_state(self, hidden_states): self.conv_state.copy_(torch.roll(self.conv_state, shifts=-1, dims=-1)) # Update state (B D W) @@ -379,15 +388,29 @@ def forward(self, hidden_states, inference_params=None): hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation - if inference_params is not None: - conv_state = inference_params.update_conv_states(hidden_states) + if inference_params.seqlen_offset > 0: + conv_state = inference_params.conv_states[self.layer_idx] + conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) + conv_state[:, :, -1] = hidden_states[:,:,0] + # out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state) + # return out, conv_state, ssm_state + else: + conv_state = hidden_states + inference_params.conv_states[self.layer_idx].copy_(nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0))) + + ssm_state = inference_params.ssm_states[self.layer_idx] - # conv_state.copy_(self.conv1d(hidden_states)[..., :seq_len]) + # conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) + # conv_state[:, :, -1] = hidden_states - hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) - # hidden_states = torch.sum(conv_state * torch.rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) + # when you have the first iter, use conv_state + hidden_states = self.act(self.conv1d(conv_state)[..., :seq_len]) + + # x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) # if self.conv1d.bias is not None: - # hidden_states = hidden_states + self.conv1d.bias + # x = x + self.conv1d.bias + # x = self.act(x).to(dtype=dtype) + # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C @@ -405,25 +428,21 @@ def forward(self, hidden_states, inference_params=None): # [batch_size, d, l, 1] [b, d, l, 1] -> [batch_size, d, l, 1] X [batch_size, 1, l, n] -> [batch_size, d, l, n] deltaB_u = (discrete_time_step[:, :, :, None] * hidden_states[:, :, :, None]) * B[:, None, :, :] - ssm_state = torch.zeros((batch_size, self.d_inner, self.d_state), device=A.device) - # ssm_state = inference_params.ssm_state # 3.c perform the recurrence y ← SSM(A, B, C)(x) - ys = [] for i in range(seq_len): ssm_state.copy_(ssm_state * dA[:, :, i, :] + deltaB_u[:, :, i, :]) # [b, d, n] X [b, n] -> [b, d] y = torch.matmul(ssm_state, C[:,i,:].unsqueeze(-1)) ys.append(y[:,:,0]) - y = torch.stack(ys, dim=1) # shape (b, l, d) + y = torch.stack(ys, dim=-1) # shape (b, l, d) - y = y + (hidden_states * self.D.to(hidden_states.dtype)[None,:,None]).transpose(1,2) - y = y * self.act(gate).transpose(1,2) # (B D) + y = y + (hidden_states * self.D.to(hidden_states.dtype)[None,:,None]) + y = y * self.act(gate) # (B D) # 4. Final linear projection - attn_outputs = self.out_proj(y) - return attn_outputs, None, ssm_state, y - return attn_outputs, conv_state, ssm_state, y + attn_outputs = self.out_proj(y.transpose(1,2)) + return attn_outputs, conv_state, ssm_state class MambaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -449,7 +468,7 @@ def __init__(self, config, layer_idx): self.layer_idx = layer_idx # self.residual_in_fp32 = config.residual_in_fp32 self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.mixer = MambaSlowMixer(config, layer_idx = layer_idx) + self.mixer = MambaSlowMixer(config, layer_idx=layer_idx) def forward(self, hidden_states, inference_params=None): residual = hidden_states @@ -457,9 +476,9 @@ def forward(self, hidden_states, inference_params=None): # if self.residual_in_fp32: # residual = residual.to(torch.float32) - hidden_states, con_states, ssm_state, y = self.mixer(hidden_states, inference_params=inference_params) + hidden_states, conv_states, ssm_state = self.mixer(hidden_states, inference_params=inference_params) hidden_states = residual + hidden_states - return hidden_states + return hidden_states, conv_states, ssm_state class MambaPreTrainedModel(PreTrainedModel): @@ -611,7 +630,6 @@ def __init__(self, config): self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) - self.layers_are_rescaled = False self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -654,11 +672,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embeddings(input_ids) - # TODO better to call _set_cache if use_cache and inference_params is None: - shape = (inputs_embeds.size(0), self.config.hidden_size, self.config.num_hidden_layers) - dtype = inputs_embeds.dtype - cache = [torch.zeros(*shape, dtype=dtype, device=inputs_embeds.device)for i in range(5)] + inference_params = MambaCache(self.config, inputs_embeds.size(0), device=inputs_embeds.device) if self.gradient_checkpointing and self.training: if use_cache: @@ -668,23 +683,23 @@ def forward( use_cache = False hidden_states = inputs_embeds - - all_last_states = () if output_attentions else None all_hidden_states = () if output_hidden_states else None for idx, layer in enumerate(self.layers): - ssm_state = None if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func(layer.__call__, hidden_states, inference_params) + hidden_states, conv_state, ssm_state = self._gradient_checkpointing_func(layer.__call__, hidden_states, inference_params) else: - hidden_states = layer(hidden_states, inference_params=inference_params) - # inference_params.conv_state_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state) + hidden_states, conv_state, ssm_state = layer(hidden_states, inference_params=inference_params) + # inference_params.update_conv_state(conv_state) + # inference_params.update_ssm_state(ssm_state) + inference_params.seqlen_offset += inputs_embeds.shape[1] + inference_params.ssm_states[idx].copy_(ssm_state) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if output_attentions: - all_self_attentions = all_last_states + (ssm_state,) + all_last_states = all_last_states + (ssm_state,) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -731,6 +746,16 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): return self.backbone.set_input_embeddings(new_embeddings) + def _update_model_kwargs_for_generation(self,outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + + model_kwargs["inference_params"] = outputs["inference_params"] + return model_kwargs + + def prepare_inputs_for_generation(self, input_ids, inference_params=None, inputs_embeds=None, attention_mask=None, **kwargs): # only last token for inputs_ids if the state is passed along. if inference_params is not None: From ee4a7ef0e4d4aa170b9a4f1f0bfd8c2f1469009c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 15 Feb 2024 16:19:20 +0900 Subject: [PATCH 019/116] Style nd updates --- .../models/mamba/modeling_mamba.py | 64 +++++++------------ 1 file changed, 22 insertions(+), 42 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index bdc39e10ee4eee..65f928f418599c 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass from pathlib import Path -from typing import List, Optional, Tuple, Union, Dict, Any +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -349,20 +349,9 @@ def __init__(self, config, batch_size, conv_dtype=torch.float32, ssm_dtype=torc expand = config.expand d_conv = config.conv_kernel - self.conv_states = { i: torch.zeros( - batch_size, d_model * expand, d_conv, device=device, dtype=conv_dtype - ) for i in range(config.num_hidden_layers)} - self.ssm_states = { i: torch.zeros( - batch_size, d_model * expand, d_state, device=device, dtype=ssm_dtype - )for i in range(config.num_hidden_layers)} + self.conv_states = { i: torch.zeros(batch_size, d_model * expand, d_conv, device=device, dtype=conv_dtype) for i in range(config.num_hidden_layers)} + self.ssm_states = { i: torch.zeros(batch_size, d_model * expand, d_state, device=device, dtype=ssm_dtype)for i in range(config.num_hidden_layers)} - def update_conv_state(self, hidden_states): - self.conv_state.copy_(torch.roll(self.conv_state, shifts=-1, dims=-1)) # Update state (B D W) - self.conv_state[:, :, -1] = hidden_states - return self.conv_state - - def update_ssm_state(self, ssm_state): - self.ssm_state.copy_(ssm_state) class MambaSlowMixer(MambaMixer): @@ -391,34 +380,17 @@ def forward(self, hidden_states, inference_params=None): if inference_params.seqlen_offset > 0: conv_state = inference_params.conv_states[self.layer_idx] conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) - conv_state[:, :, -1] = hidden_states[:,:,0] - # out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state) - # return out, conv_state, ssm_state + conv_state[:, :, -1].copy_(hidden_states[:,:,0]) + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1).unsqueeze(-1) else: - conv_state = hidden_states inference_params.conv_states[self.layer_idx].copy_(nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0))) - - ssm_state = inference_params.ssm_states[self.layer_idx] - - # conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) # Update state (B D W) - # conv_state[:, :, -1] = hidden_states - - # when you have the first iter, use conv_state - hidden_states = self.act(self.conv1d(conv_state)[..., :seq_len]) - - # x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1) # (B D) - # if self.conv1d.bias is not None: - # x = x + self.conv1d.bias - # x = self.act(x).to(dtype=dtype) - + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C x_dbl = self.x_proj(hidden_states.transpose(1,2)) time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1) discrete_time_step = self.dt_proj(time_step) - - # discrete_time_step = discrete_time_step.transpose(0,1) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) @@ -429,6 +401,7 @@ def forward(self, hidden_states, inference_params=None): deltaB_u = (discrete_time_step[:, :, :, None] * hidden_states[:, :, :, None]) * B[:, None, :, :] # 3.c perform the recurrence y ← SSM(A, B, C)(x) + ssm_state = inference_params.ssm_states[self.layer_idx] ys = [] for i in range(seq_len): ssm_state.copy_(ssm_state * dA[:, :, i, :] + deltaB_u[:, :, i, :]) @@ -436,13 +409,11 @@ def forward(self, hidden_states, inference_params=None): y = torch.matmul(ssm_state, C[:,i,:].unsqueeze(-1)) ys.append(y[:,:,0]) y = torch.stack(ys, dim=-1) # shape (b, l, d) - y = y + (hidden_states * self.D.to(hidden_states.dtype)[None,:,None]) y = y * self.act(gate) # (B D) - # 4. Final linear projection attn_outputs = self.out_proj(y.transpose(1,2)) - return attn_outputs, conv_state, ssm_state + return attn_outputs, None, ssm_state class MambaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -502,8 +473,18 @@ def _init_weights(self, module): nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=self.config.initializer_range) - - + # + # # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max + # dt = torch.exp( + # torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + # + math.log(dt_min) + # ).clamp(min=dt_init_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + # inv_dt = dt + torch.log(-torch.expm1(-dt)) + # with torch.no_grad(): + # self.dt_proj.bias.copy_(inv_dt) + # # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit + # self.dt_proj.bias._no_reinit = True @dataclass @@ -690,16 +671,15 @@ def forward( hidden_states, conv_state, ssm_state = self._gradient_checkpointing_func(layer.__call__, hidden_states, inference_params) else: hidden_states, conv_state, ssm_state = layer(hidden_states, inference_params=inference_params) - # inference_params.update_conv_state(conv_state) - # inference_params.update_ssm_state(ssm_state) - inference_params.seqlen_offset += inputs_embeds.shape[1] inference_params.ssm_states[idx].copy_(ssm_state) + # inference_params.conv_states[idx].copy_(conv_state) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if output_attentions: all_last_states = all_last_states + (ssm_state,) + inference_params.seqlen_offset += inputs_embeds.shape[1] if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) From d8c195fbdf486015cf758b656842f8b5c7e05da3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 15 Feb 2024 16:45:16 +0900 Subject: [PATCH 020/116] nits --- .../models/mamba/modeling_mamba.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 65f928f418599c..05a8468c60fccd 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -486,6 +486,30 @@ def _init_weights(self, module): # # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit # self.dt_proj.bias._no_reinit = True + # if isinstance(module, nn.Linear): + # if module.bias is not None: + # if not getattr(module.bias, "_no_reinit", False): + # nn.init.zeros_(module.bias) + # elif isinstance(module, nn.Embedding): + # nn.init.normal_(module.weight, std=initializer_range) + # + # if rescale_prenorm_residual: + # # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # # + # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + # for name, p in module.named_parameters(): + # if name in ["out_proj.weight", "fc2.weight"]: + # # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # # We need to reinit p since this code could be called multiple times + # # Having just p *= scale would repeatedly scale it down + # nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + # with torch.no_grad(): + # p /= math.sqrt(n_residuals_per_layer * n_layer) + @dataclass class MambaOutput(ModelOutput): From e64fedc29f594e86ea2f06682d2348a2aafc874c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 16 Feb 2024 15:32:51 +0900 Subject: [PATCH 021/116] Match logits --- .../models/mamba/configuration_mamba.py | 2 +- src/transformers/models/mamba/modeling_mamba.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index de55b6a0561557..79b3e27568a62d 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -103,7 +103,7 @@ def __init__( dt_rank="auto", tie_word_embeddings=True, use_bias=False, - use_conv_bias=False, + use_conv_bias=True, hidden_act="silu", initializer_range=0.1, **kwargs, diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 05a8468c60fccd..f59e13d93c4eac 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -289,7 +289,7 @@ def __init__(self, config, layer_idx): # D "skip" parameter self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32 self.D._no_weight_decay = True - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=config.use_conv_bias) + self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=config.use_bias) def forward(self, hidden_states: torch.Tensor, inference_params=None): """ @@ -374,6 +374,7 @@ def forward(self, hidden_states, inference_params=None): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1,2) + # self.in_proj(hidden_states.squeeze(0)) hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation @@ -381,7 +382,10 @@ def forward(self, hidden_states, inference_params=None): conv_state = inference_params.conv_states[self.layer_idx] conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) conv_state[:, :, -1].copy_(hidden_states[:,:,0]) - hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1).unsqueeze(-1) + + hidden_states = self.act(torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + self.conv1d.bias) + hidden_states = hidden_states.unsqueeze(-1) + else: inference_params.conv_states[self.layer_idx].copy_(nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0))) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) @@ -448,7 +452,7 @@ def forward(self, hidden_states, inference_params=None): # residual = residual.to(torch.float32) hidden_states, conv_states, ssm_state = self.mixer(hidden_states, inference_params=inference_params) - hidden_states = residual + hidden_states + hidden_states = residual.to(torch.float32) + hidden_states return hidden_states, conv_states, ssm_state @@ -636,7 +640,7 @@ def __init__(self, config): self.layers = nn.ModuleList([MambaBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - + self.norm_f = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) # Initialize weights and apply final processing self.post_init() @@ -705,6 +709,8 @@ def forward( all_last_states = all_last_states + (ssm_state,) inference_params.seqlen_offset += inputs_embeds.shape[1] + hidden_states = self.norm_f(hidden_states) + if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) From aee558f37bb54f605395723b718f1a59d905b3b3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 16 Feb 2024 15:43:52 +0900 Subject: [PATCH 022/116] Add a test --- .../models/mamba/modeling_mamba.py | 5 ++-- tests/models/mamba/test_modeling_mamba.py | 26 ++++++++++++------- 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index f59e13d93c4eac..650e21110189c0 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -353,7 +353,7 @@ def __init__(self, config, batch_size, conv_dtype=torch.float32, ssm_dtype=torc self.ssm_states = { i: torch.zeros(batch_size, d_model * expand, d_state, device=device, dtype=ssm_dtype)for i in range(config.num_hidden_layers)} -class MambaSlowMixer(MambaMixer): +class MambaMixerSlow(MambaMixer): def forward(self, hidden_states, inference_params=None): """ @@ -374,7 +374,6 @@ def forward(self, hidden_states, inference_params=None): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1,2) - # self.in_proj(hidden_states.squeeze(0)) hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation @@ -443,7 +442,7 @@ def __init__(self, config, layer_idx): self.layer_idx = layer_idx # self.residual_in_fp32 = config.residual_in_fp32 self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.mixer = MambaSlowMixer(config, layer_idx=layer_idx) + self.mixer = MambaMixerSlow(config, layer_idx=layer_idx) def forward(self, hidden_states, inference_params=None): residual = hidden_states diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 7c361bb252113e..ddba249b5e203d 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -421,23 +421,31 @@ def test_model_from_pretrained(self): self.assertIsNotNone(model) -@unittest.skipIf( - not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204" -) -@slow class MAMBAIntegrationTests(unittest.TestCase): def setUp(self): self.model_id = "state-spaces/mamba-2.8b" self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) def test_simple_generate(self): - expected_output = "Hello my name is Jasmine and I am a newbie to the" - model = MambaForCausalLM.from_pretrained(self.model_id).to(torch_device) + from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer + import torch - input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) - output = model.generate(input_ids, max_new_tokens=10) - output_sentence = self.tokenizer.decode(output[0].tolist()) + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + tokenizer.pad_token = tokenizer.eos_token + + model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m", vocab_size=50280, num_hidden_layers=24, torch_dtype=torch.float16) + model.to(torch_device) + model.config.use_cache = True + input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(torch_device) + + logits = model(input_ids = input_ids) + + EXPECTED_LOGITS = torch.tensor([ -6.7070, -24.7656, -6.4766, -6.0078, -9.7812, -13.0703, -11.4688, -10.6562, -9.3359, -9.4766, -9.1719, -7.9102, -13.0469, -8.7266, -8.4297, -8.4766, -9.1094, -11.5234, -11.1250, -11.7812, -12.1562, -12.8359, -12.1797, -13.4062, -13.6406, -13.4141, -13.6562, -9.2344, -7.9805, -7.2188, -9.9219, -9.1719, -7.8438, -9.1250, -10.1094, -10.2344, -10.2266, -9.7578, -11.0000, -10.6406], device='cuda:0',dtype=torch.float16) # fmt: skip + + torch.testing.assert_allclose(logits, EXPECTED_LOGITS) + out = model.generate(input_ids, max_new_tokens=10) + print(tokenizer.batch_decode(out)) self.assertEqual(output_sentence, expected_output) def test_simple_generate_bf16(self): From eae5f4524f04833929f0e92c6d9d4ca5778bfa8c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 16 Feb 2024 15:47:08 +0900 Subject: [PATCH 023/116] Add expected generated text --- tests/models/mamba/test_modeling_mamba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index ddba249b5e203d..39bf0b138399ed 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -445,8 +445,8 @@ def test_simple_generate(self): torch.testing.assert_allclose(logits, EXPECTED_LOGITS) out = model.generate(input_ids, max_new_tokens=10) - print(tokenizer.batch_decode(out)) - self.assertEqual(output_sentence, expected_output) + output_sentence = tokenizer.decode(out[0,:]) + self.assertEqual(output_sentence, ["Hey how are you doing?\n\nI'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm"]) def test_simple_generate_bf16(self): expected_output = "Hello my name is Jasmine and I am a newbie to the" From 1f8e8d0a2b817b6f49f99a8b83fb66de41c5ccbc Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 16 Feb 2024 16:04:02 +0900 Subject: [PATCH 024/116] nits doc, imports and styling --- .../models/mamba/modeling_mamba.py | 228 +++--------------- tests/models/mamba/test_modeling_mamba.py | 7 +- 2 files changed, 42 insertions(+), 193 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 650e21110189c0..bcaf453ae7d3b9 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -24,6 +24,19 @@ from torch import nn from torch.nn import CrossEntropyLoss + +is_causal_conv_1d_available = None +if is_causal_conv_1d_available(): + from causal_conv1d import causal_conv1d_update, causal_conv1d_fn +else : + causal_conv1d_update, causal_conv1d_fn = None, None + +is_kernel_compiled = None +if is_kernel_compiled(): + from causal_conv1d import causal_conv1d_update, causal_conv1d_fn +else : + selective_scan_update, selective_scan_fn = None, None + from ...activations import ACT2FN from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -38,11 +51,11 @@ logger = logging.get_logger(__name__) -_CHECKPOINT_FOR_DOC = "state-spaces/mamba-2.8b" +_CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m" _CONFIG_FOR_DOC = "MambaConfig" MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "state-spaces/mamba-2.8b", + "state-spaces/mamba-130m", # See all Mamba models at https://huggingface.co/models?filter=mamba ] @@ -50,7 +63,6 @@ mamba_cuda_kernel = None - # Copied from transformers.models.mamba.modeling_mamba.load_mamba_cuda_kernel with mamba->MAMBA,mamba->mamba def load_mamba_cuda_kernel(context_length): from torch.utils.cpp_extension import load as load_kernel @@ -84,167 +96,6 @@ def load_mamba_cuda_kernel(context_length): mamba_cuda_kernel.max_seq_length = context_length -class MambaMixer(torch.autograd.Function): - @staticmethod - def forward(ctx, time_decay, time_first, key, value, state=None, return_state=False): - batch_size, seq_len, hidden_size = key.size() - if seq_len > mamba_cuda_kernel.max_seq_length: - raise ValueError( - f"Cannot process a batch with {seq_len} tokens at the same time, use a maximum of " - f"{mamba_cuda_kernel.max_seq_length} with this model." - ) - if batch_size * hidden_size % min(hidden_size, 32) != 0: - raise ValueError( - f"The product of batch size ({batch_size}) and hidden size ({hidden_size}) needs to be a round " - f"multiple of {min(hidden_size, 32)}." - ) - - ctx.input_dtype = key.dtype - - if ( - time_decay.device.type != "cuda" - or time_first.device.type != "cuda" - or key.device.type != "cuda" - or value.device.type != "cuda" - ): - raise ValueError("Calling the CUDA kernel for mamba attention requires all tensors to be on CUDA devices.") - - time_decay = -torch.exp(time_decay.float().contiguous()) - if key.dtype == torch.float16: - time_first = time_first.float() - key = key.float() - value = value.float() - time_first = time_first.contiguous() - key = key.contiguous() - value = value.contiguous() - # The CUDA kernel will fill this tensor. - output = torch.empty_like(key, memory_format=torch.contiguous_format) - if return_state or state is not None: - if state is None: - state = torch.zeros( - batch_size, - hidden_size, - 3, - dtype=torch.float32, - device=key.device, - memory_format=torch.contiguous_format, - ) - state[:, :, 2] -= 1e38 - else: - state = torch.cat([s.unsqueeze(2) for s in state], dim=2).contiguous() - if key.dtype == torch.bfloat16: - forward_func = mamba_cuda_kernel.forward_with_state_bf16 - else: - forward_func = mamba_cuda_kernel.forward_with_state - forward_func(time_decay, time_first, key, value, output, state) - else: - forward_func = mamba_cuda_kernel.forward_bf16 if key.dtype == torch.bfloat16 else mamba_cuda_kernel.forward - forward_func(time_decay, time_first, key, value, output) - - ctx.save_for_backward(time_decay, time_first, key, value, output) - - if state is not None: - state = [s.squeeze(2) for s in torch.chunk(state, 3, dim=2)] - - return output.to(ctx.input_dtype), state - - @staticmethod - # g stands for grad - def backward(ctx, g_output, g_state=None): - input_dtype = ctx.input_dtype - - time_decay, time_first, key, value, output = ctx.saved_tensors - # The CUDA kernel will fill those tensors. - g_time_decay = torch.empty_like( - time_decay, - memory_format=torch.contiguous_format, - dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float32, - ) - g_time_first = torch.empty_like(time_first, memory_format=torch.contiguous_format) - g_key = torch.empty_like(key, memory_format=torch.contiguous_format) - g_value = torch.empty_like(value, memory_format=torch.contiguous_format) - - if input_dtype == torch.float16: - g_output = g_output.float() - backward_func = mamba_cuda_kernel.backward_bf16 if input_dtype == torch.bfloat16 else mamba_cuda_kernel.backward - backward_func( - time_decay, - time_first, - key, - value, - output, - g_output.contiguous(), - g_time_decay, - g_time_first, - g_key, - g_value, - ) - - return ( - g_time_decay.to(input_dtype), - g_time_first.to(input_dtype), - g_key.to(input_dtype), - g_value.to(input_dtype), - None, - None, - ) - - -def mamba_linear_attention_cpu(time_decay, time_first, key, value, state=None, return_state=False): - # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed - # within a torch.no_grad. - _, seq_length, _ = key.size() - output = torch.zeros_like(key) - - if state is None: - num_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - den_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - max_state = torch.zeros_like(key[:, 0], dtype=torch.float32) - 1e38 - else: - num_state, den_state, max_state = state - # For numerical stability - # real_numerator_state = num_state * torch.exp(max_state) - # real_denominator_state = den_state * torch.exp(max_state) - - time_decay = -torch.exp(time_decay) - - for current_index in range(seq_length): - current_key = key[:, current_index].float() - current_value = value[:, current_index] - - # mamba computation at time t - max_for_output = torch.maximum(max_state, current_key + time_first) - e1 = torch.exp(max_state - max_for_output) - e2 = torch.exp(current_key + time_first - max_for_output) - numerator = e1 * num_state + e2 * current_value - denominator = e1 * den_state + e2 - output[:, current_index] = (numerator / denominator).to(output.dtype) - - # Update state for next iteration - max_for_state = torch.maximum(max_state + time_decay, current_key) - e1 = torch.exp(max_state + time_decay - max_for_state) - e2 = torch.exp(current_key - max_for_state) - num_state = e1 * num_state + e2 * current_value - den_state = e1 * den_state + e2 - max_state = max_for_state - - if return_state or state is not None: - state = [num_state, den_state, max_state] - - return output, state - - -def mamba_mixer_forward(time_decay, time_first, key, value, state=None, return_state=False): - no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, key, value]) - # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version - # in this case). - one_token = key.size(1) == 1 - if mamba_cuda_kernel is None or no_cuda or one_token: - return mamba_linear_attention_cpu(time_decay, time_first, key, value, state=state, return_state=return_state) - else: - return MambaMixer.apply(time_decay, time_first, key, value, state, return_state) - - class MambaMixer(nn.Module): def __init__(self, config, layer_idx): @@ -255,11 +106,8 @@ def __init__(self, config, layer_idx): self.expand = config.expand self.d_inner = int(self.expand * self.d_model) self.time_step_rank = math.ceil(self.d_model / 16) if config.time_step_rank == "auto" else config.time_step_rank - # self.use_fast_path = config.use_fast_path self.layer_idx = layer_idx - - self.conv1d = nn.Conv1d( in_channels=self.d_inner, out_channels=self.d_inner, @@ -292,17 +140,14 @@ def __init__(self, config, layer_idx): self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=config.use_bias) def forward(self, hidden_states: torch.Tensor, inference_params=None): - """ - hidden_states: (B, L, D) - Returns: same shape as hidden_states - """ - _, seqlen, _ = hidden_states.shape - # conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] - + batch_size, seq_len, _ = hidden_states.shape + # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1,2) - hidden_states, z = projected_states.chunk(2, dim=1) + hidden_states, gate = projected_states.chunk(2, dim=1) + # 2. Convolution sequence transformation if inference_params is not None and inference_params.seq_offset > 0: + conv_state = inference_params.conv_states[self.layer_idx] hidden_states = causal_conv1d_update( hidden_states, conv_state, @@ -311,7 +156,7 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): self.activation, ) else: - conv_state = F.pad(hidden_states, (self.d_conv - seqlen, 0)) + conv_state = nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0)) hidden_states = causal_conv1d_fn( hidden_states=hidden_states, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), @@ -319,31 +164,33 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): activation=self.activation, ) - # We're careful here about the layout, to avoid extra transposes. - # We want dt to have d as the slowest moving dimension - # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. - x_dbl = self.x_proj(rearrange(hidden_states, "b d l -> (b l) d")) # (bl d) - dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) - dt = self.dt_proj.weight @ dt.t() - dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) - B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() - C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + x_dbl = self.x_proj(hidden_states.transpose(1,2)) + time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1) + discrete_time_step = self.dt_proj(time_step) + # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1,2) + + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + ssm_state = inference_params.ssm_states[self.layer_idx] if inference_params is not None and inference_params.seq_offset > 0: y, _ = selective_scan_update( - ssm_state, hidden_states, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True + ssm_state, hidden_states, discrete_time_step, self.negA, B, C, self.D, z=gate, dt_bias=self.dt_proj.bias, dt_softplus=True ) else: y, last_state = selective_scan_fn( - hidden_states, dt, self.negA, B, C, self.D.float(),z=z,delta_bias=self.dt_proj.bias.float(),delta_softplus=True,return_last_state=True, + hidden_states, discrete_time_step, self.negA, B, C, self.D.float(),z=gate,delta_bias=self.dt_proj.bias.float(),delta_softplus=True,return_last_state=True, ) - y = rearrange(y, "b d l -> b l d") - attn_outputs = self.out_proj(y) - return attn_outputs, conv_state, last_state + # 4. Final linear projection + attn_outputs = self.out_proj(y.transpose(1,2)) + return attn_outputs, conv_state, ssm_state class MambaCache: def __init__(self, config, batch_size, conv_dtype=torch.float32, ssm_dtype=torch.float32, device=None): self.seqlen_offset = 0 + d_model = config.hidden_size d_state = config.state_size expand = config.expand @@ -699,6 +546,7 @@ def forward( else: hidden_states, conv_state, ssm_state = layer(hidden_states, inference_params=inference_params) inference_params.ssm_states[idx].copy_(ssm_state) + # TODO maybe for torch.compile + graph do things here # inference_params.conv_states[idx].copy_(conv_state) if output_hidden_states: diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 39bf0b138399ed..b1fe0ab3552163 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2024 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -421,15 +421,16 @@ def test_model_from_pretrained(self): self.assertIsNotNone(model) -class MAMBAIntegrationTests(unittest.TestCase): +class MambaIntegrationTests(unittest.TestCase): def setUp(self): self.model_id = "state-spaces/mamba-2.8b" self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) def test_simple_generate(self): - from transformers import MambaConfig, MambaForCausalLM, AutoTokenizer import torch + from transformers import AutoTokenizer, MambaForCausalLM + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") tokenizer.pad_token = tokenizer.eos_token From 3cc06e5bf6f72cb8f3f037f633907f75b3e9b159 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 16 Feb 2024 16:04:43 +0900 Subject: [PATCH 025/116] style --- src/transformers/__init__.py | 18 +-- src/transformers/models/auto/modeling_auto.py | 8 +- .../models/mamba/configuration_mamba.py | 3 +- .../models/mamba/modeling_mamba.py | 127 +++++++++++------- tests/models/mamba/test_modeling_mamba.py | 17 ++- 5 files changed, 107 insertions(+), 66 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index fc5f655da62b79..e24a2659f2654d 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -570,6 +570,7 @@ "LxmertTokenizer", ], "models.m2m_100": ["M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP", "M2M100Config"], + "models.mamba": ["MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MambaConfig"], "models.marian": ["MarianConfig"], "models.markuplm": [ "MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP", @@ -747,7 +748,6 @@ "RoFormerTokenizer", ], "models.rwkv": ["RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP", "RwkvConfig"], - "models.mamba": ["MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP", "MambaConfig"], "models.sam": [ "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP", "SamConfig", @@ -2552,6 +2552,14 @@ "M2M100PreTrainedModel", ] ) + _import_structure["models.mamba"].extend( + [ + "MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST", + "MambaForCausalLM", + "MambaModel", + "MambaPreTrainedModel", + ] + ) _import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"]) _import_structure["models.markuplm"].extend( [ @@ -3131,14 +3139,6 @@ "RwkvPreTrainedModel", ] ) - _import_structure["models.mamba"].extend( - [ - "MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST", - "MambaForCausalLM", - "MambaModel", - "MambaPreTrainedModel", - ] - ) _import_structure["models.sam"].extend( [ "SAM_PRETRAINED_MODEL_ARCHIVE_LIST", diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 59d2332f597c4c..14dcad6ded1e76 100755 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -133,6 +133,7 @@ ("luke", "LukeModel"), ("lxmert", "LxmertModel"), ("m2m_100", "M2M100Model"), + ("mamba", "MambaModel"), ("marian", "MarianModel"), ("markuplm", "MarkupLMModel"), ("mask2former", "Mask2FormerModel"), @@ -188,7 +189,6 @@ ("roc_bert", "RoCBertModel"), ("roformer", "RoFormerModel"), ("rwkv", "RwkvModel"), - ("mamba", "MambaModel"), ("sam", "SamModel"), ("seamless_m4t", "SeamlessM4TModel"), ("seamless_m4t_v2", "SeamlessM4Tv2Model"), @@ -282,6 +282,7 @@ ("longformer", "LongformerForMaskedLM"), ("luke", "LukeForMaskedLM"), ("lxmert", "LxmertForPreTraining"), + ("mamba", "MambaForCausalLM"), ("mega", "MegaForMaskedLM"), ("megatron-bert", "MegatronBertForPreTraining"), ("mobilebert", "MobileBertForPreTraining"), @@ -297,7 +298,6 @@ ("roberta-prelayernorm", "RobertaPreLayerNormForMaskedLM"), ("roc_bert", "RoCBertForPreTraining"), ("rwkv", "RwkvForCausalLM"), - ("mamba", "MambaForCausalLM"), ("splinter", "SplinterForPreTraining"), ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), @@ -364,6 +364,7 @@ ("longt5", "LongT5ForConditionalGeneration"), ("luke", "LukeForMaskedLM"), ("m2m_100", "M2M100ForConditionalGeneration"), + ("mamba", "MambaForCausalLM"), ("marian", "MarianMTModel"), ("mega", "MegaForMaskedLM"), ("megatron-bert", "MegatronBertForCausalLM"), @@ -387,7 +388,6 @@ ("roc_bert", "RoCBertForMaskedLM"), ("roformer", "RoFormerForMaskedLM"), ("rwkv", "RwkvForCausalLM"), - ("mamba", "MambaForCausalLM"), ("speech_to_text", "Speech2TextForConditionalGeneration"), ("squeezebert", "SqueezeBertForMaskedLM"), ("switch_transformers", "SwitchTransformersForConditionalGeneration"), @@ -436,6 +436,7 @@ ("gpt_neox_japanese", "GPTNeoXJapaneseForCausalLM"), ("gptj", "GPTJForCausalLM"), ("llama", "LlamaForCausalLM"), + ("mamba", "MambaForCausalLM"), ("marian", "MarianForCausalLM"), ("mbart", "MBartForCausalLM"), ("mega", "MegaForCausalLM"), @@ -462,7 +463,6 @@ ("roc_bert", "RoCBertForCausalLM"), ("roformer", "RoFormerForCausalLM"), ("rwkv", "RwkvForCausalLM"), - ("mamba", "MambaForCausalLM"), ("speech_to_text_2", "Speech2Text2ForCausalLM"), ("transfo-xl", "TransfoXLLMHeadModel"), ("trocr", "TrOCRForCausalLM"), diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 79b3e27568a62d..b5534865b9c7d5 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -27,7 +27,6 @@ } - class MambaConfig(PretrainedConfig): """ This is the configuration class to store the configuration of a [`MambaModel`]. It is used to instantiate a MAMBA @@ -123,7 +122,7 @@ def __init__( self.pad_token_id = pad_token_id self.use_bias = use_bias self.use_conv_bias = use_conv_bias - self.hidden_act=hidden_act + self.hidden_act = hidden_act self.initializer_range = initializer_range super().__init__( diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index bcaf453ae7d3b9..dca87bc66e57de 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -24,19 +24,6 @@ from torch import nn from torch.nn import CrossEntropyLoss - -is_causal_conv_1d_available = None -if is_causal_conv_1d_available(): - from causal_conv1d import causal_conv1d_update, causal_conv1d_fn -else : - causal_conv1d_update, causal_conv1d_fn = None, None - -is_kernel_compiled = None -if is_kernel_compiled(): - from causal_conv1d import causal_conv1d_update, causal_conv1d_fn -else : - selective_scan_update, selective_scan_fn = None, None - from ...activations import ACT2FN from ...modeling_utils import PreTrainedModel from ...utils import ( @@ -49,6 +36,18 @@ from .configuration_mamba import MambaConfig +is_causal_conv_1d_available = None +if is_causal_conv_1d_available(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + causal_conv1d_update, causal_conv1d_fn = None, None + +is_kernel_compiled = None +if is_kernel_compiled(): + from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +else: + selective_scan_update, selective_scan_fn = None, None + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m" @@ -60,9 +59,9 @@ ] - mamba_cuda_kernel = None + # Copied from transformers.models.mamba.modeling_mamba.load_mamba_cuda_kernel with mamba->MAMBA,mamba->mamba def load_mamba_cuda_kernel(context_length): from torch.utils.cpp_extension import load as load_kernel @@ -97,7 +96,6 @@ def load_mamba_cuda_kernel(context_length): class MambaMixer(nn.Module): - def __init__(self, config, layer_idx): super().__init__() self.d_model = config.hidden_size @@ -105,7 +103,9 @@ def __init__(self, config, layer_idx): self.d_conv = config.conv_kernel self.expand = config.expand self.d_inner = int(self.expand * self.d_model) - self.time_step_rank = math.ceil(self.d_model / 16) if config.time_step_rank == "auto" else config.time_step_rank + self.time_step_rank = ( + math.ceil(self.d_model / 16) if config.time_step_rank == "auto" else config.time_step_rank + ) self.layer_idx = layer_idx self.conv1d = nn.Conv1d( @@ -129,7 +129,7 @@ def __init__(self, config, layer_idx): # S4D real initialization. These are not discretized! # THe core is to load them, compute the discrete states, then write the updates state. # Keeps the memory bounded - A = torch.arange(1, self.d_state + 1, dtype=torch.float32)[None,:].expand(self.d_inner, -1).contiguous() + A = torch.arange(1, self.d_state + 1, dtype=torch.float32)[None, :].expand(self.d_inner, -1).contiguous() A_log = torch.log(A) # Keep A_log in fp32 self.A_log = nn.Parameter(A_log) self.A_log._no_weight_decay = True @@ -142,7 +142,7 @@ def __init__(self, config, layer_idx): def forward(self, hidden_states: torch.Tensor, inference_params=None): batch_size, seq_len, _ = hidden_states.shape # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states).transpose(1,2) + projected_states = self.in_proj(hidden_states).transpose(1, 2) hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation @@ -166,29 +166,48 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C - x_dbl = self.x_proj(hidden_states.transpose(1,2)) + x_dbl = self.x_proj(hidden_states.transpose(1, 2)) time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1) discrete_time_step = self.dt_proj(time_step) # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) - discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1,2) + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # 3.c perform the recurrence y ← SSM(A, B, C)(x) ssm_state = inference_params.ssm_states[self.layer_idx] if inference_params is not None and inference_params.seq_offset > 0: y, _ = selective_scan_update( - ssm_state, hidden_states, discrete_time_step, self.negA, B, C, self.D, z=gate, dt_bias=self.dt_proj.bias, dt_softplus=True + ssm_state, + hidden_states, + discrete_time_step, + self.negA, + B, + C, + self.D, + z=gate, + dt_bias=self.dt_proj.bias, + dt_softplus=True, ) else: y, last_state = selective_scan_fn( - hidden_states, discrete_time_step, self.negA, B, C, self.D.float(),z=gate,delta_bias=self.dt_proj.bias.float(),delta_softplus=True,return_last_state=True, + hidden_states, + discrete_time_step, + self.negA, + B, + C, + self.D.float(), + z=gate, + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=True, ) # 4. Final linear projection - attn_outputs = self.out_proj(y.transpose(1,2)) + attn_outputs = self.out_proj(y.transpose(1, 2)) return attn_outputs, conv_state, ssm_state + class MambaCache: - def __init__(self, config, batch_size, conv_dtype=torch.float32, ssm_dtype=torch.float32, device=None): + def __init__(self, config, batch_size, conv_dtype=torch.float32, ssm_dtype=torch.float32, device=None): self.seqlen_offset = 0 d_model = config.hidden_size @@ -196,12 +215,17 @@ def __init__(self, config, batch_size, conv_dtype=torch.float32, ssm_dtype=torc expand = config.expand d_conv = config.conv_kernel - self.conv_states = { i: torch.zeros(batch_size, d_model * expand, d_conv, device=device, dtype=conv_dtype) for i in range(config.num_hidden_layers)} - self.ssm_states = { i: torch.zeros(batch_size, d_model * expand, d_state, device=device, dtype=ssm_dtype)for i in range(config.num_hidden_layers)} + self.conv_states = { + i: torch.zeros(batch_size, d_model * expand, d_conv, device=device, dtype=conv_dtype) + for i in range(config.num_hidden_layers) + } + self.ssm_states = { + i: torch.zeros(batch_size, d_model * expand, d_state, device=device, dtype=ssm_dtype) + for i in range(config.num_hidden_layers) + } class MambaMixerSlow(MambaMixer): - def forward(self, hidden_states, inference_params=None): """ @@ -220,31 +244,33 @@ def forward(self, hidden_states, inference_params=None): batch_size, seq_len, _ = hidden_states.shape # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states).transpose(1,2) + projected_states = self.in_proj(hidden_states).transpose(1, 2) hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation if inference_params.seqlen_offset > 0: conv_state = inference_params.conv_states[self.layer_idx] conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) - conv_state[:, :, -1].copy_(hidden_states[:,:,0]) + conv_state[:, :, -1].copy_(hidden_states[:, :, 0]) - hidden_states = self.act(torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + self.conv1d.bias) + hidden_states = self.act(torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + self.conv1d.bias) hidden_states = hidden_states.unsqueeze(-1) else: - inference_params.conv_states[self.layer_idx].copy_(nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0))) + inference_params.conv_states[self.layer_idx].copy_( + nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0)) + ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C - x_dbl = self.x_proj(hidden_states.transpose(1,2)) + x_dbl = self.x_proj(hidden_states.transpose(1, 2)) time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1) discrete_time_step = self.dt_proj(time_step) A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) - discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1,2) + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch_size, d, l, 1] X [1, d, 1, n] -> [batch_size, d, l, n] dA = torch.exp(discrete_time_step[:, :, :, None] * A[None, :, None, :]) # [batch_size, d, l, 1] [b, d, l, 1] -> [batch_size, d, l, 1] X [batch_size, 1, l, n] -> [batch_size, d, l, n] @@ -256,15 +282,16 @@ def forward(self, hidden_states, inference_params=None): for i in range(seq_len): ssm_state.copy_(ssm_state * dA[:, :, i, :] + deltaB_u[:, :, i, :]) # [b, d, n] X [b, n] -> [b, d] - y = torch.matmul(ssm_state, C[:,i,:].unsqueeze(-1)) - ys.append(y[:,:,0]) + y = torch.matmul(ssm_state, C[:, i, :].unsqueeze(-1)) + ys.append(y[:, :, 0]) y = torch.stack(ys, dim=-1) # shape (b, l, d) - y = y + (hidden_states * self.D.to(hidden_states.dtype)[None,:,None]) - y = y * self.act(gate) # (B D) + y = y + (hidden_states * self.D.to(hidden_states.dtype)[None, :, None]) + y = y * self.act(gate) # (B D) # 4. Final linear projection - attn_outputs = self.out_proj(y.transpose(1,2)) + attn_outputs = self.out_proj(y.transpose(1, 2)) return attn_outputs, None, ssm_state + class MambaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ @@ -528,7 +555,7 @@ def forward( inputs_embeds = self.embeddings(input_ids) if use_cache and inference_params is None: - inference_params = MambaCache(self.config, inputs_embeds.size(0), device=inputs_embeds.device) + inference_params = MambaCache(self.config, inputs_embeds.size(0), device=inputs_embeds.device) if self.gradient_checkpointing and self.training: if use_cache: @@ -542,7 +569,9 @@ def forward( all_hidden_states = () if output_hidden_states else None for idx, layer in enumerate(self.layers): if self.gradient_checkpointing and self.training: - hidden_states, conv_state, ssm_state = self._gradient_checkpointing_func(layer.__call__, hidden_states, inference_params) + hidden_states, conv_state, ssm_state = self._gradient_checkpointing_func( + layer.__call__, hidden_states, inference_params + ) else: hidden_states, conv_state, ssm_state = layer(hidden_states, inference_params=inference_params) inference_params.ssm_states[idx].copy_(ssm_state) @@ -562,7 +591,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(hidden_states for hidden_states in [hidden_states, inference_params, all_hidden_states, all_last_states] if hidden_states is not None) + return tuple( + hidden_states + for hidden_states in [hidden_states, inference_params, all_hidden_states, all_last_states] + if hidden_states is not None + ) return MambaOutput( last_hidden_state=hidden_states, @@ -603,17 +636,19 @@ def get_input_embeddings(self): def set_input_embeddings(self, new_embeddings): return self.backbone.set_input_embeddings(new_embeddings) - def _update_model_kwargs_for_generation(self,outputs: ModelOutput, + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False, standardize_cache_format: bool = False, ) -> Dict[str, Any]: - model_kwargs["inference_params"] = outputs["inference_params"] return model_kwargs - - def prepare_inputs_for_generation(self, input_ids, inference_params=None, inputs_embeds=None, attention_mask=None, **kwargs): + def prepare_inputs_for_generation( + self, input_ids, inference_params=None, inputs_embeds=None, attention_mask=None, **kwargs + ): # only last token for inputs_ids if the state is passed along. if inference_params is not None: input_ids = input_ids[:, -1].unsqueeze(-1) @@ -637,7 +672,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - inference_params: Optional[torch.FloatTensor] = None, + inference_params: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index b1fe0ab3552163..5fb9f33a13ab74 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -434,20 +434,27 @@ def test_simple_generate(self): tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") tokenizer.pad_token = tokenizer.eos_token - model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m", vocab_size=50280, num_hidden_layers=24, torch_dtype=torch.float16) + model = MambaForCausalLM.from_pretrained( + "state-spaces/mamba-130m", vocab_size=50280, num_hidden_layers=24, torch_dtype=torch.float16 + ) model.to(torch_device) model.config.use_cache = True input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(torch_device) - logits = model(input_ids = input_ids) + logits = model(input_ids=input_ids) - EXPECTED_LOGITS = torch.tensor([ -6.7070, -24.7656, -6.4766, -6.0078, -9.7812, -13.0703, -11.4688, -10.6562, -9.3359, -9.4766, -9.1719, -7.9102, -13.0469, -8.7266, -8.4297, -8.4766, -9.1094, -11.5234, -11.1250, -11.7812, -12.1562, -12.8359, -12.1797, -13.4062, -13.6406, -13.4141, -13.6562, -9.2344, -7.9805, -7.2188, -9.9219, -9.1719, -7.8438, -9.1250, -10.1094, -10.2344, -10.2266, -9.7578, -11.0000, -10.6406], device='cuda:0',dtype=torch.float16) # fmt: skip + EXPECTED_LOGITS = torch.tensor([ -6.7070, -24.7656, -6.4766, -6.0078, -9.7812, -13.0703, -11.4688, -10.6562, -9.3359, -9.4766, -9.1719, -7.9102, -13.0469, -8.7266, -8.4297, -8.4766, -9.1094, -11.5234, -11.1250, -11.7812, -12.1562, -12.8359, -12.1797, -13.4062, -13.6406, -13.4141, -13.6562, -9.2344, -7.9805, -7.2188, -9.9219, -9.1719, -7.8438, -9.1250, -10.1094, -10.2344, -10.2266, -9.7578, -11.0000, -10.6406], device='cuda:0',dtype=torch.float16) # fmt: skip torch.testing.assert_allclose(logits, EXPECTED_LOGITS) out = model.generate(input_ids, max_new_tokens=10) - output_sentence = tokenizer.decode(out[0,:]) - self.assertEqual(output_sentence, ["Hey how are you doing?\n\nI'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm"]) + output_sentence = tokenizer.decode(out[0, :]) + self.assertEqual( + output_sentence, + [ + "Hey how are you doing?\n\nI'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm" + ], + ) def test_simple_generate_bf16(self): expected_output = "Hello my name is Jasmine and I am a newbie to the" From 5a5324c2bfa7dc8d3b5701ee39d9e6bc9c21fdd0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 16 Feb 2024 16:08:55 +0900 Subject: [PATCH 026/116] oups --- src/transformers/models/auto/configuration_auto.py | 6 +++--- src/transformers/models/mamba/modeling_mamba.py | 6 ++---- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 40400aeb6a2a16..2aa5c20d0c95b8 100755 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -135,6 +135,7 @@ ("luke", "LukeConfig"), ("lxmert", "LxmertConfig"), ("m2m_100", "M2M100Config"), + ("mamba", "MambaConfig"), ("marian", "MarianConfig"), ("markuplm", "MarkupLMConfig"), ("mask2former", "Mask2FormerConfig"), @@ -196,7 +197,6 @@ ("roc_bert", "RoCBertConfig"), ("roformer", "RoFormerConfig"), ("rwkv", "RwkvConfig"), - ("mamba", "MambaConfig"), ("sam", "SamConfig"), ("seamless_m4t", "SeamlessM4TConfig"), ("seamless_m4t_v2", "SeamlessM4Tv2Config"), @@ -367,6 +367,7 @@ ("luke", "LUKE_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("lxmert", "LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("m2m_100", "M2M_100_PRETRAINED_CONFIG_ARCHIVE_MAP"), + ("mamba", "MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("markuplm", "MARKUPLM_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("mask2former", "MASK2FORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("maskformer", "MASKFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -421,7 +422,6 @@ ("roc_bert", "ROC_BERT_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("roformer", "ROFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("rwkv", "RWKV_PRETRAINED_CONFIG_ARCHIVE_MAP"), - ("mamba", "MAMBA_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("sam", "SAM_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("seamless_m4t", "SEAMLESS_M4T_PRETRAINED_CONFIG_ARCHIVE_MAP"), ("seamless_m4t_v2", "SEAMLESS_M4T_V2_PRETRAINED_CONFIG_ARCHIVE_MAP"), @@ -603,6 +603,7 @@ ("lxmert", "LXMERT"), ("m2m_100", "M2M100"), ("madlad-400", "MADLAD-400"), + ("mamba", "Mamba"), ("marian", "Marian"), ("markuplm", "MarkupLM"), ("mask2former", "Mask2Former"), @@ -671,7 +672,6 @@ ("roc_bert", "RoCBert"), ("roformer", "RoFormer"), ("rwkv", "RWKV"), - ("mamba", "Mamba"), ("sam", "SAM"), ("seamless_m4t", "SeamlessM4T"), ("seamless_m4t_v2", "SeamlessM4Tv2"), diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index dca87bc66e57de..a41b523f42e08f 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -36,14 +36,12 @@ from .configuration_mamba import MambaConfig -is_causal_conv_1d_available = None -if is_causal_conv_1d_available(): +if False and is_causal_conv_1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update else: causal_conv1d_update, causal_conv1d_fn = None, None -is_kernel_compiled = None -if is_kernel_compiled(): +if False and is_kernel_compiled(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update else: selective_scan_update, selective_scan_fn = None, None From 81303f4ddac28096da3a57ea483741891b004e29 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 08:15:38 +0100 Subject: [PATCH 027/116] dont install kernels, invite users to install the required kernels --- src/transformers/kernels/mamba/Makefile | 28 + .../mamba/selective_scan/reverse_scan.cuh | 401 ------------- .../mamba/selective_scan/selective_scan.cpp | 497 ---------------- .../mamba/selective_scan/selective_scan.h | 101 ---- .../selective_scan_bwd_bf16_complex.cu | 9 - .../selective_scan_bwd_bf16_real.cu | 9 - .../selective_scan_bwd_fp16_complex.cu | 9 - .../selective_scan_bwd_fp16_real.cu | 9 - .../selective_scan_bwd_fp32_complex.cu | 9 - .../selective_scan_bwd_fp32_real.cu | 9 - .../selective_scan_bwd_kernel.cuh | 531 ------------------ .../selective_scan/selective_scan_common.h | 221 -------- .../selective_scan/selective_scan_fwd_bf16.cu | 10 - .../selective_scan/selective_scan_fwd_fp16.cu | 10 - .../selective_scan/selective_scan_fwd_fp32.cu | 10 - .../selective_scan_fwd_kernel.cuh | 345 ------------ .../mamba/selective_scan/static_switch.h | 25 - .../selective_scan/uninitialized_copy.cuh | 69 --- .../models/mamba/configuration_mamba.py | 4 + .../models/mamba/modeling_mamba.py | 96 +--- 20 files changed, 57 insertions(+), 2345 deletions(-) create mode 100644 src/transformers/kernels/mamba/Makefile delete mode 100644 src/transformers/kernels/mamba/selective_scan/reverse_scan.cuh delete mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan.cpp delete mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan.h delete mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_complex.cu delete mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_real.cu delete mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_complex.cu delete mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_real.cu delete mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_complex.cu delete mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_real.cu delete mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_kernel.cuh delete mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_common.h delete mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_bf16.cu delete mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp16.cu delete mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp32.cu delete mode 100644 src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_kernel.cuh delete mode 100644 src/transformers/kernels/mamba/selective_scan/static_switch.h delete mode 100644 src/transformers/kernels/mamba/selective_scan/uninitialized_copy.cuh diff --git a/src/transformers/kernels/mamba/Makefile b/src/transformers/kernels/mamba/Makefile new file mode 100644 index 00000000000000..f4dec8688abc3e --- /dev/null +++ b/src/transformers/kernels/mamba/Makefile @@ -0,0 +1,28 @@ +selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137 + +causal-conv1d: + rm -rf causal-conv1d + git clone https://github.com/Dao-AILab/causal-conv1d.git + +build-causal-conv1d: causal-conv1d + cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag + cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build + +install-causal-conv1d: build-causal-conv1d + pip uninstall causal-conv1d -y || true + cd causal-conv1d/ && pip install . + +# selective-scan dependends on causal-conv1d +selective-scan: + rm -rf mamba + git clone https://github.com/state-spaces/mamba.git mamba + +build-selective-scan: selective-scan + cd mamba/ && git fetch && git checkout $(selective_scan_commit) + cd mamba && python setup.py build + +install-selective-scan: install-causal-conv1d build-selective-scan + pip uninstall selective-scan-cuda -y || true + cd mamba && pip install . + +build-all: build-causal-conv1d build-selective-scan \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/reverse_scan.cuh b/src/transformers/kernels/mamba/selective_scan/reverse_scan.cuh deleted file mode 100644 index d7e93174bb391d..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/reverse_scan.cuh +++ /dev/null @@ -1,401 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include - -#include -#include -#include -// #include -#include "uninitialized_copy.cuh" - -/** - * Perform a reverse sequential reduction over \p LENGTH elements of the \p input array. The aggregate is returned. - */ -template < - int LENGTH, - typename T, - typename ReductionOp> -__device__ __forceinline__ T ThreadReverseReduce(const T (&input)[LENGTH], ReductionOp reduction_op) { - static_assert(LENGTH > 0); - T retval = input[LENGTH - 1]; - #pragma unroll - for (int i = LENGTH - 2; i >= 0; --i) { retval = reduction_op(retval, input[i]); } - return retval; -} - -/** - * Perform a sequential inclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. - */ -template < - int LENGTH, - typename T, - typename ScanOp> -__device__ __forceinline__ T ThreadReverseScanInclusive( - const T (&input)[LENGTH], - T (&output)[LENGTH], - ScanOp scan_op, - const T postfix) -{ - T inclusive = postfix; - #pragma unroll - for (int i = LENGTH - 1; i >= 0; --i) { - inclusive = scan_op(inclusive, input[i]); - output[i] = inclusive; - } -} - -/** - * Perform a sequential exclusive postfix reverse scan over the statically-sized \p input array, seeded with the specified \p postfix. The aggregate is returned. - */ -template < - int LENGTH, - typename T, - typename ScanOp> -__device__ __forceinline__ T ThreadReverseScanExclusive( - const T (&input)[LENGTH], - T (&output)[LENGTH], - ScanOp scan_op, - const T postfix) -{ - // Careful, output maybe be aliased to input - T exclusive = postfix; - T inclusive; - #pragma unroll - for (int i = LENGTH - 1; i >= 0; --i) { - inclusive = scan_op(exclusive, input[i]); - output[i] = exclusive; - exclusive = inclusive; - } - return inclusive; -} - - -/** - * \brief WarpReverseScan provides SHFL-based variants of parallel postfix scan of items partitioned across a CUDA thread warp. - * - * LOGICAL_WARP_THREADS must be a power-of-two - */ -template < - typename T, ///< Data type being scanned - int LOGICAL_WARP_THREADS ///< Number of threads per logical warp - > -struct WarpReverseScan { - //--------------------------------------------------------------------- - // Constants and type definitions - //--------------------------------------------------------------------- - - /// Whether the logical warp size and the PTX warp size coincide - static constexpr bool IS_ARCH_WARP = (LOGICAL_WARP_THREADS == CUB_WARP_THREADS(0)); - /// The number of warp scan steps - static constexpr int STEPS = cub::Log2::VALUE; - static_assert(LOGICAL_WARP_THREADS == 1 << STEPS); - - - //--------------------------------------------------------------------- - // Thread fields - //--------------------------------------------------------------------- - - /// Lane index in logical warp - unsigned int lane_id; - - /// Logical warp index in 32-thread physical warp - unsigned int warp_id; - - /// 32-thread physical warp member mask of logical warp - unsigned int member_mask; - - //--------------------------------------------------------------------- - // Construction - //--------------------------------------------------------------------- - - /// Constructor - explicit __device__ __forceinline__ - WarpReverseScan() - : lane_id(cub::LaneId()) - , warp_id(IS_ARCH_WARP ? 0 : (lane_id / LOGICAL_WARP_THREADS)) - , member_mask(cub::WarpMask(warp_id)) - { - if (!IS_ARCH_WARP) { - lane_id = lane_id % LOGICAL_WARP_THREADS; - } - } - - - /// Broadcast - __device__ __forceinline__ T Broadcast( - T input, ///< [in] The value to broadcast - int src_lane) ///< [in] Which warp lane is to do the broadcasting - { - return cub::ShuffleIndex(input, src_lane, member_mask); - } - - - /// Inclusive scan - template - __device__ __forceinline__ void InclusiveReverseScan( - T input, ///< [in] Calling thread's input item. - T &inclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. - ScanOpT scan_op) ///< [in] Binary scan operator - { - inclusive_output = input; - #pragma unroll - for (int STEP = 0; STEP < STEPS; STEP++) { - int offset = 1 << STEP; - T temp = cub::ShuffleDown( - inclusive_output, offset, LOGICAL_WARP_THREADS - 1, member_mask - ); - // Perform scan op if from a valid peer - inclusive_output = static_cast(lane_id) >= LOGICAL_WARP_THREADS - offset - ? inclusive_output : scan_op(temp, inclusive_output); - } - } - - /// Exclusive scan - // Get exclusive from inclusive - template - __device__ __forceinline__ void ExclusiveReverseScan( - T input, ///< [in] Calling thread's input item. - T &exclusive_output, ///< [out] Calling thread's output item. May be aliased with \p input. - ScanOpT scan_op, ///< [in] Binary scan operator - T &warp_aggregate) ///< [out] Warp-wide aggregate reduction of input items. - { - T inclusive_output; - InclusiveReverseScan(input, inclusive_output, scan_op); - warp_aggregate = cub::ShuffleIndex(inclusive_output, 0, member_mask); - // initial value unknown - exclusive_output = cub::ShuffleDown( - inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask - ); - } - - /** - * \brief Computes both inclusive and exclusive reverse scans using the specified binary scan functor across the calling warp. Because no initial value is supplied, the \p exclusive_output computed for the last warp-lane is undefined. - */ - template - __device__ __forceinline__ void ReverseScan( - T input, ///< [in] Calling thread's input item. - T &inclusive_output, ///< [out] Calling thread's inclusive-scan output item. - T &exclusive_output, ///< [out] Calling thread's exclusive-scan output item. - ScanOpT scan_op) ///< [in] Binary scan operator - { - InclusiveReverseScan(input, inclusive_output, scan_op); - // initial value unknown - exclusive_output = cub::ShuffleDown( - inclusive_output, 1, LOGICAL_WARP_THREADS - 1, member_mask - ); - } - -}; - -/** - * \brief BlockReverseScan provides variants of raking-based parallel postfix scan across a CUDA thread block. - */ -template < - typename T, ///< Data type being scanned - int BLOCK_DIM_X, ///< The thread block length in threads along the X dimension - bool MEMOIZE=false ///< Whether or not to buffer outer raking scan partials to incur fewer shared memory reads at the expense of higher register pressure - > -struct BlockReverseScan { - //--------------------------------------------------------------------- - // Types and constants - //--------------------------------------------------------------------- - - /// Constants - /// The thread block size in threads - static constexpr int BLOCK_THREADS = BLOCK_DIM_X; - - /// Layout type for padded thread block raking grid - using BlockRakingLayout = cub::BlockRakingLayout; - // The number of reduction elements is not a multiple of the number of raking threads for now - static_assert(BlockRakingLayout::UNGUARDED); - - /// Number of raking threads - static constexpr int RAKING_THREADS = BlockRakingLayout::RAKING_THREADS; - /// Number of raking elements per warp synchronous raking thread - static constexpr int SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH; - /// Cooperative work can be entirely warp synchronous - static constexpr bool WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS)); - - /// WarpReverseScan utility type - using WarpReverseScan = WarpReverseScan; - - /// Shared memory storage layout type - struct _TempStorage { - typename BlockRakingLayout::TempStorage raking_grid; ///< Padded thread block raking grid - }; - - - /// Alias wrapper allowing storage to be unioned - struct TempStorage : cub::Uninitialized<_TempStorage> {}; - - - //--------------------------------------------------------------------- - // Per-thread fields - //--------------------------------------------------------------------- - - // Thread fields - _TempStorage &temp_storage; - unsigned int linear_tid; - T cached_segment[SEGMENT_LENGTH]; - - - //--------------------------------------------------------------------- - // Utility methods - //--------------------------------------------------------------------- - - /// Performs upsweep raking reduction, returning the aggregate - template - __device__ __forceinline__ T Upsweep(ScanOp scan_op) { - T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); - // Read data into registers - #pragma unroll - for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } - T raking_partial = cached_segment[SEGMENT_LENGTH - 1]; - #pragma unroll - for (int i = SEGMENT_LENGTH - 2; i >= 0; --i) { - raking_partial = scan_op(raking_partial, cached_segment[i]); - } - return raking_partial; - } - - - /// Performs exclusive downsweep raking scan - template - __device__ __forceinline__ void ExclusiveDownsweep( - ScanOp scan_op, - T raking_partial) - { - T *smem_raking_ptr = BlockRakingLayout::RakingPtr(temp_storage.raking_grid, linear_tid); - // Read data back into registers - if (!MEMOIZE) { - #pragma unroll - for (int i = 0; i < SEGMENT_LENGTH; ++i) { cached_segment[i] = smem_raking_ptr[i]; } - } - ThreadReverseScanExclusive(cached_segment, cached_segment, scan_op, raking_partial); - // Write data back to smem - #pragma unroll - for (int i = 0; i < SEGMENT_LENGTH; ++i) { smem_raking_ptr[i] = cached_segment[i]; } - } - - - //--------------------------------------------------------------------- - // Constructors - //--------------------------------------------------------------------- - - /// Constructor - __device__ __forceinline__ BlockReverseScan( - TempStorage &temp_storage) - : - temp_storage(temp_storage.Alias()), - linear_tid(cub::RowMajorTid(BLOCK_DIM_X, 1, 1)) - {} - - - /// Computes an exclusive thread block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes one input element. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. - template < - typename ScanOp, - typename BlockPostfixCallbackOp> - __device__ __forceinline__ void ExclusiveReverseScan( - T input, ///< [in] Calling thread's input item - T &exclusive_output, ///< [out] Calling thread's output item (may be aliased to \p input) - ScanOp scan_op, ///< [in] Binary scan operator - BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a thread block-wide postfix to be applied to all inputs. - { - if (WARP_SYNCHRONOUS) { - // Short-circuit directly to warp-synchronous scan - T block_aggregate; - WarpReverseScan warp_scan; - warp_scan.ExclusiveReverseScan(input, exclusive_output, scan_op, block_aggregate); - // Obtain warp-wide postfix in lane0, then broadcast to other lanes - T block_postfix = block_postfix_callback_op(block_aggregate); - block_postfix = warp_scan.Broadcast(block_postfix, 0); - exclusive_output = linear_tid == BLOCK_THREADS - 1 ? block_postfix : scan_op(block_postfix, exclusive_output); - } else { - // Place thread partial into shared memory raking grid - T *placement_ptr = BlockRakingLayout::PlacementPtr(temp_storage.raking_grid, linear_tid); - detail::uninitialized_copy(placement_ptr, input); - cub::CTA_SYNC(); - // Reduce parallelism down to just raking threads - if (linear_tid < RAKING_THREADS) { - WarpReverseScan warp_scan; - // Raking upsweep reduction across shared partials - T upsweep_partial = Upsweep(scan_op); - // Warp-synchronous scan - T exclusive_partial, block_aggregate; - warp_scan.ExclusiveReverseScan(upsweep_partial, exclusive_partial, scan_op, block_aggregate); - // Obtain block-wide postfix in lane0, then broadcast to other lanes - T block_postfix = block_postfix_callback_op(block_aggregate); - block_postfix = warp_scan.Broadcast(block_postfix, 0); - // Update postfix with warpscan exclusive partial - T downsweep_postfix = linear_tid == RAKING_THREADS - 1 - ? block_postfix : scan_op(block_postfix, exclusive_partial); - // Exclusive raking downsweep scan - ExclusiveDownsweep(scan_op, downsweep_postfix); - } - cub::CTA_SYNC(); - // Grab thread postfix from shared memory - exclusive_output = *placement_ptr; - - // // Compute warp scan in each warp. - // // The exclusive output from the last lane in each warp is invalid. - // T inclusive_output; - // WarpReverseScan warp_scan; - // warp_scan.ReverseScan(input, inclusive_output, exclusive_output, scan_op); - - // // Compute the warp-wide postfix and block-wide aggregate for each warp. Warp postfix for the last warp is invalid. - // T block_aggregate; - // T warp_postfix = ComputeWarpPostfix(scan_op, inclusive_output, block_aggregate); - - // // Apply warp postfix to our lane's partial - // if (warp_id != 0) { - // exclusive_output = scan_op(warp_postfix, exclusive_output); - // if (lane_id == 0) { exclusive_output = warp_postfix; } - // } - - // // Use the first warp to determine the thread block postfix, returning the result in lane0 - // if (warp_id == 0) { - // T block_postfix = block_postfix_callback_op(block_aggregate); - // if (lane_id == 0) { - // // Share the postfix with all threads - // detail::uninitialized_copy(&temp_storage.block_postfix, - // block_postfix); - - // exclusive_output = block_postfix; // The block postfix is the exclusive output for tid0 - // } - // } - - // cub::CTA_SYNC(); - - // // Incorporate thread block postfix into outputs - // T block_postfix = temp_storage.block_postfix; - // if (linear_tid > 0) { exclusive_output = scan_op(block_postfix, exclusive_output); } - } - } - - - /** - * \brief Computes an inclusive block-wide postfix scan using the specified binary \p scan_op functor. Each thread contributes an array of consecutive input elements. the call-back functor \p block_postfix_callback_op is invoked by the first warp in the block, and the value returned by lane0 in that warp is used as the "seed" value that logically postfixes the thread block's scan inputs. Also provides every thread with the block-wide \p block_aggregate of all inputs. - */ - template < - int ITEMS_PER_THREAD, - typename ScanOp, - typename BlockPostfixCallbackOp> - __device__ __forceinline__ void InclusiveReverseScan( - T (&input)[ITEMS_PER_THREAD], ///< [in] Calling thread's input items - T (&output)[ITEMS_PER_THREAD], ///< [out] Calling thread's output items (may be aliased to \p input) - ScanOp scan_op, ///< [in] Binary scan functor - BlockPostfixCallbackOp &block_postfix_callback_op) ///< [in-out] [warp0 only] Call-back functor for specifying a block-wide postfix to be applied to the logical input sequence. - { - // Reduce consecutive thread items in registers - T thread_postfix = ThreadReverseReduce(input, scan_op); - // Exclusive thread block-scan - ExclusiveReverseScan(thread_postfix, thread_postfix, scan_op, block_postfix_callback_op); - // Inclusive scan in registers with postfix as seed - ThreadReverseScanInclusive(input, output, scan_op, thread_postfix); - } - -}; \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan.cpp b/src/transformers/kernels/mamba/selective_scan/selective_scan.cpp deleted file mode 100644 index cde867cd32d39b..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/selective_scan.cpp +++ /dev/null @@ -1,497 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#include -#include -#include -#include - -#include "selective_scan.h" - -#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") - -#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ - if (ITYPE == at::ScalarType::Half) { \ - using input_t = at::Half; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::BFloat16) { \ - using input_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (ITYPE == at::ScalarType::Float) { \ - using input_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Half) { \ - using weight_t = at::Half; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::BFloat16) { \ - using weight_t = at::BFloat16; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ - } - -#define DISPATCH_WTYPE_FLOAT_AND_COMPLEX(WTYPE, NAME, ...) \ - if (WTYPE == at::ScalarType::Float) { \ - using weight_t = float; \ - __VA_ARGS__(); \ - } else if (WTYPE == at::ScalarType::ComplexFloat) { \ - using weight_t = c10::complex; \ - __VA_ARGS__(); \ - } else { \ - AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \ - } - -template -void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); - -template -void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); - -void set_ssm_params_fwd(SSMParamsBase ¶ms, - // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t dstate, - const size_t n_groups, - const size_t n_chunks, - const bool is_variable_B, - const bool is_variable_C, - // device pointers - const at::Tensor u, - const at::Tensor delta, - const at::Tensor A, - const at::Tensor B, - const at::Tensor C, - const at::Tensor out, - const at::Tensor z, - const at::Tensor out_z, - void* D_ptr, - void* delta_bias_ptr, - void* x_ptr, - bool has_z, - bool delta_softplus) { - - // Reset the parameters - memset(¶ms, 0, sizeof(params)); - - params.batch = batch; - params.dim = dim; - params.seqlen = seqlen; - params.dstate = dstate; - params.n_groups = n_groups; - params.n_chunks = n_chunks; - params.dim_ngroups_ratio = dim / n_groups; - - params.delta_softplus = delta_softplus; - - params.is_variable_B = is_variable_B; - params.is_variable_C = is_variable_C; - - // Set the pointers and strides. - params.u_ptr = u.data_ptr(); - params.delta_ptr = delta.data_ptr(); - params.A_ptr = A.data_ptr(); - params.B_ptr = B.data_ptr(); - params.C_ptr = C.data_ptr(); - params.D_ptr = D_ptr; - params.delta_bias_ptr = delta_bias_ptr; - params.out_ptr = out.data_ptr(); - params.x_ptr = x_ptr; - params.z_ptr = has_z ? z.data_ptr() : nullptr; - params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; - // All stride are in elements, not bytes. - params.A_d_stride = A.stride(0); - params.A_dstate_stride = A.stride(1); - if (!is_variable_B) { - params.B_d_stride = B.stride(0); - } else { - params.B_batch_stride = B.stride(0); - params.B_group_stride = B.stride(1); - } - params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); - if (!is_variable_C) { - params.C_d_stride = C.stride(0); - } else { - params.C_batch_stride = C.stride(0); - params.C_group_stride = C.stride(1); - } - params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); - params.u_batch_stride = u.stride(0); - params.u_d_stride = u.stride(1); - params.delta_batch_stride = delta.stride(0); - params.delta_d_stride = delta.stride(1); - if (has_z) { - params.z_batch_stride = z.stride(0); - params.z_d_stride = z.stride(1); - params.out_z_batch_stride = out_z.stride(0); - params.out_z_d_stride = out_z.stride(1); - } - params.out_batch_stride = out.stride(0); - params.out_d_stride = out.stride(1); -} - -void set_ssm_params_bwd(SSMParamsBwd ¶ms, - // sizes - const size_t batch, - const size_t dim, - const size_t seqlen, - const size_t dstate, - const size_t n_groups, - const size_t n_chunks, - const bool is_variable_B, - const bool is_variable_C, - // device pointers - const at::Tensor u, - const at::Tensor delta, - const at::Tensor A, - const at::Tensor B, - const at::Tensor C, - const at::Tensor z, - const at::Tensor out, - const at::Tensor out_z, - void* D_ptr, - void* delta_bias_ptr, - void* x_ptr, - const at::Tensor dout, - const at::Tensor du, - const at::Tensor ddelta, - const at::Tensor dA, - const at::Tensor dB, - const at::Tensor dC, - const at::Tensor dz, - void* dD_ptr, - void* ddelta_bias_ptr, - bool has_z, - bool delta_softplus, - bool recompute_out_z) { - // Pass in "dout" instead of "out", we're not gonna use "out" unless we have z - set_ssm_params_fwd(params, batch, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, - u, delta, A, B, C, has_z ? out : dout, - has_z ? z : dout, - // If not recompute_out_z, pass dout instead of out_z. - // This won't be used by the bwd kernel - recompute_out_z ? out_z : dout, - D_ptr, delta_bias_ptr, x_ptr, has_z, delta_softplus); - if (!recompute_out_z) { params.out_z_ptr = nullptr; } - - // Set the pointers and strides. - params.dout_ptr = dout.data_ptr(); - params.du_ptr = du.data_ptr(); - params.dA_ptr = dA.data_ptr(); - params.dB_ptr = dB.data_ptr(); - params.dC_ptr = dC.data_ptr(); - params.dD_ptr = dD_ptr; - params.ddelta_ptr = ddelta.data_ptr(); - params.ddelta_bias_ptr = ddelta_bias_ptr; - params.dz_ptr = has_z ? dz.data_ptr() : nullptr; - // All stride are in elements, not bytes. - params.dout_batch_stride = dout.stride(0); - params.dout_d_stride = dout.stride(1); - params.dA_d_stride = dA.stride(0); - params.dA_dstate_stride = dA.stride(1); - if (!is_variable_B) { - params.dB_d_stride = dB.stride(0); - } else { - params.dB_batch_stride = dB.stride(0); - params.dB_group_stride = dB.stride(1); - } - params.dB_dstate_stride = !is_variable_B ? dB.stride(1) : dB.stride(2); - if (!is_variable_C) { - params.dC_d_stride = dC.stride(0); - } else { - params.dC_batch_stride = dC.stride(0); - params.dC_group_stride = dC.stride(1); - } - params.dC_dstate_stride = !is_variable_C ? dC.stride(1) : dC.stride(2); - params.du_batch_stride = du.stride(0); - params.du_d_stride = du.stride(1); - params.ddelta_batch_stride = ddelta.stride(0); - params.ddelta_d_stride = ddelta.stride(1); - if (has_z) { - params.dz_batch_stride = dz.stride(0); - params.dz_d_stride = dz.stride(1); - } -} - -std::vector -selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, - const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, - const c10::optional &D_, - const c10::optional &z_, - const c10::optional &delta_bias_, - bool delta_softplus) { - auto input_type = u.scalar_type(); - auto weight_type = A.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); - - const bool is_variable_B = B.dim() >= 3; - const bool is_variable_C = C.dim() >= 3; - const bool is_complex = weight_type == at::ScalarType::ComplexFloat; - - TORCH_CHECK(delta.scalar_type() == input_type); - TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); - TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); - - TORCH_CHECK(u.is_cuda()); - TORCH_CHECK(delta.is_cuda()); - TORCH_CHECK(A.is_cuda()); - TORCH_CHECK(B.is_cuda()); - TORCH_CHECK(C.is_cuda()); - - TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); - TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); - - const auto sizes = u.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int dstate = A.size(1); - const int n_groups = is_variable_B ? B.size(1) : 1; - - TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); - - CHECK_SHAPE(u, batch_size, dim, seqlen); - CHECK_SHAPE(delta, batch_size, dim, seqlen); - CHECK_SHAPE(A, dim, dstate); - if (!is_variable_B) { - CHECK_SHAPE(B, dim, dstate); - } else { - CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); - TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); - } - if (!is_variable_C) { - CHECK_SHAPE(C, dim, dstate); - } else { - CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); - TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); - } - - if (D_.has_value()) { - auto D = D_.value(); - TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(D.is_cuda()); - TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); - CHECK_SHAPE(D, dim); - } - - if (delta_bias_.has_value()) { - auto delta_bias = delta_bias_.value(); - TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(delta_bias.is_cuda()); - TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); - CHECK_SHAPE(delta_bias, dim); - } - - at::Tensor z, out_z; - const bool has_z = z_.has_value(); - if (has_z) { - z = z_.value(); - TORCH_CHECK(z.scalar_type() == input_type); - TORCH_CHECK(z.is_cuda()); - TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); - CHECK_SHAPE(z, batch_size, dim, seqlen); - out_z = torch::empty_like(z); - } - - const int n_chunks = (seqlen + 2048 - 1) / 2048; - // const int n_chunks = (seqlen + 1024 - 1) / 1024; - // at::Tensor out = torch::empty_like(u); - // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout - at::Tensor out = torch::empty_like(delta); - at::Tensor x; - x = torch::empty({batch_size, dim, n_chunks, dstate * 2}, u.options().dtype(weight_type)); - - SSMParamsBase params; - set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, - u, delta, A, B, C, out, z, out_z, - D_.has_value() ? D_.value().data_ptr() : nullptr, - delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, - x.data_ptr(), - has_z, - delta_softplus); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)u.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_fwd", [&] { - selective_scan_fwd_cuda(params, stream); - }); - }); - std::vector result = {out, x}; - if (has_z) { result.push_back(out_z); } - return result; -} - -std::vector -selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, - const at::Tensor &A, const at::Tensor &B, const at::Tensor &C, - const c10::optional &D_, - const c10::optional &z_, - const c10::optional &delta_bias_, - const at::Tensor &dout, - const c10::optional &x_, - const c10::optional &out_, - c10::optional &dz_, - bool delta_softplus, - bool recompute_out_z) { - auto input_type = u.scalar_type(); - auto weight_type = A.scalar_type(); - TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); - TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat); - - const bool is_variable_B = B.dim() >= 3; - const bool is_variable_C = C.dim() >= 3; - const bool is_complex = weight_type == at::ScalarType::ComplexFloat; - - TORCH_CHECK(delta.scalar_type() == input_type); - TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); - TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); - TORCH_CHECK(dout.scalar_type() == input_type); - - TORCH_CHECK(u.is_cuda()); - TORCH_CHECK(delta.is_cuda()); - TORCH_CHECK(A.is_cuda()); - TORCH_CHECK(B.is_cuda()); - TORCH_CHECK(C.is_cuda()); - TORCH_CHECK(dout.is_cuda()); - - TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); - TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); - TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1); - - const auto sizes = u.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; - const int dstate = A.size(1); - const int n_groups = is_variable_B ? B.size(1) : 1; - - TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); - - CHECK_SHAPE(u, batch_size, dim, seqlen); - CHECK_SHAPE(delta, batch_size, dim, seqlen); - CHECK_SHAPE(A, dim, dstate); - if (!is_variable_B) { - CHECK_SHAPE(B, dim, dstate); - } else { - CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); - TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); - } - if (!is_variable_C) { - CHECK_SHAPE(C, dim, dstate); - } else { - CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); - TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); - } - CHECK_SHAPE(dout, batch_size, dim, seqlen); - - if (D_.has_value()) { - auto D = D_.value(); - TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(D.is_cuda()); - TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); - CHECK_SHAPE(D, dim); - } - - if (delta_bias_.has_value()) { - auto delta_bias = delta_bias_.value(); - TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); - TORCH_CHECK(delta_bias.is_cuda()); - TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); - CHECK_SHAPE(delta_bias, dim); - } - - at::Tensor z, out, dz, out_z; - const bool has_z = z_.has_value(); - if (has_z) { - z = z_.value(); - TORCH_CHECK(z.scalar_type() == input_type); - TORCH_CHECK(z.is_cuda()); - TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); - CHECK_SHAPE(z, batch_size, dim, seqlen); - - TORCH_CHECK(out_.has_value()); - out = out_.value(); - TORCH_CHECK(out.scalar_type() == input_type); - TORCH_CHECK(out.is_cuda()); - TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1); - CHECK_SHAPE(out, batch_size, dim, seqlen); - - if (dz_.has_value()) { - dz = dz_.value(); - TORCH_CHECK(dz.scalar_type() == input_type); - TORCH_CHECK(dz.is_cuda()); - TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1); - CHECK_SHAPE(dz, batch_size, dim, seqlen); - } else { - dz = torch::empty_like(z); - } - if (recompute_out_z) { - out_z = torch::empty_like(out); - } - } - - const int n_chunks = (seqlen + 2048 - 1) / 2048; - // const int n_chunks = (seqlen + 1024 - 1) / 1024; - if (n_chunks > 1) { TORCH_CHECK(x_.has_value()); } - if (x_.has_value()) { - auto x = x_.value(); - TORCH_CHECK(x.scalar_type() == weight_type); - TORCH_CHECK(x.is_cuda()); - TORCH_CHECK(x.is_contiguous()); - CHECK_SHAPE(x, batch_size, dim, n_chunks, 2 * dstate); - } - - at::Tensor du = torch::empty_like(u); - at::Tensor ddelta = torch::empty_like(delta); - at::Tensor dA = torch::zeros_like(A); - at::Tensor dB = !is_variable_B ? torch::zeros_like(B) : torch::zeros_like(B, B.options().dtype(torch::kFloat32)); - at::Tensor dC = !is_variable_C ? torch::zeros_like(C) : torch::zeros_like(C, C.options().dtype(torch::kFloat32)); - at::Tensor dD; - if (D_.has_value()) { dD = torch::zeros_like(D_.value()); } - at::Tensor ddelta_bias; - if (delta_bias_.has_value()) { ddelta_bias = torch::zeros_like(delta_bias_.value()); } - - SSMParamsBwd params; - set_ssm_params_bwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, - u, delta, A, B, C, z, out, out_z, - D_.has_value() ? D_.value().data_ptr() : nullptr, - delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, - x_.has_value() ? x_.value().data_ptr() : nullptr, - dout, du, ddelta, dA, dB, dC, dz, - D_.has_value() ? dD.data_ptr() : nullptr, - delta_bias_.has_value() ? ddelta_bias.data_ptr() : nullptr, - has_z, delta_softplus, recompute_out_z); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)u.get_device()}; - auto stream = at::cuda::getCurrentCUDAStream().stream(); - DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_bwd", [&] { - DISPATCH_WTYPE_FLOAT_AND_COMPLEX(A.scalar_type(), "selective_scan_bwd", [&] { - selective_scan_bwd_cuda(params, stream); - }); - }); - std::vector result = {du, ddelta, dA, dB.to(B.dtype()), dC.to(C.dtype()), dD, ddelta_bias}; - if (has_z) { result.push_back(dz); } - if (recompute_out_z) { result.push_back(out_z); } - return result; -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("fwd", &selective_scan_fwd, "Selective scan forward"); - m.def("bwd", &selective_scan_bwd, "Selective scan backward"); -} diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan.h b/src/transformers/kernels/mamba/selective_scan/selective_scan.h deleted file mode 100644 index e2c7bcdbd5ddad..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/selective_scan.h +++ /dev/null @@ -1,101 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct SSMScanParamsBase { - using index_t = uint32_t; - - int batch, seqlen, n_chunks; - index_t a_batch_stride; - index_t b_batch_stride; - index_t out_batch_stride; - - // Common data pointers. - void *__restrict__ a_ptr; - void *__restrict__ b_ptr; - void *__restrict__ out_ptr; - void *__restrict__ x_ptr; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct SSMParamsBase { - using index_t = uint32_t; - - int batch, dim, seqlen, dstate, n_groups, n_chunks; - int dim_ngroups_ratio; - bool is_variable_B; - bool is_variable_C; - - bool delta_softplus; - - index_t A_d_stride; - index_t A_dstate_stride; - index_t B_batch_stride; - index_t B_d_stride; - index_t B_dstate_stride; - index_t B_group_stride; - index_t C_batch_stride; - index_t C_d_stride; - index_t C_dstate_stride; - index_t C_group_stride; - index_t u_batch_stride; - index_t u_d_stride; - index_t delta_batch_stride; - index_t delta_d_stride; - index_t z_batch_stride; - index_t z_d_stride; - index_t out_batch_stride; - index_t out_d_stride; - index_t out_z_batch_stride; - index_t out_z_d_stride; - - // Common data pointers. - void *__restrict__ A_ptr; - void *__restrict__ B_ptr; - void *__restrict__ C_ptr; - void *__restrict__ D_ptr; - void *__restrict__ u_ptr; - void *__restrict__ delta_ptr; - void *__restrict__ delta_bias_ptr; - void *__restrict__ out_ptr; - void *__restrict__ x_ptr; - void *__restrict__ z_ptr; - void *__restrict__ out_z_ptr; -}; - -struct SSMParamsBwd: public SSMParamsBase { - index_t dout_batch_stride; - index_t dout_d_stride; - index_t dA_d_stride; - index_t dA_dstate_stride; - index_t dB_batch_stride; - index_t dB_group_stride; - index_t dB_d_stride; - index_t dB_dstate_stride; - index_t dC_batch_stride; - index_t dC_group_stride; - index_t dC_d_stride; - index_t dC_dstate_stride; - index_t du_batch_stride; - index_t du_d_stride; - index_t dz_batch_stride; - index_t dz_d_stride; - index_t ddelta_batch_stride; - index_t ddelta_d_stride; - - // Common data pointers. - void *__restrict__ dout_ptr; - void *__restrict__ dA_ptr; - void *__restrict__ dB_ptr; - void *__restrict__ dC_ptr; - void *__restrict__ dD_ptr; - void *__restrict__ du_ptr; - void *__restrict__ dz_ptr; - void *__restrict__ ddelta_ptr; - void *__restrict__ ddelta_bias_ptr; -}; diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_complex.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_complex.cu deleted file mode 100644 index c55f0e858af4eb..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_complex.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_real.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_real.cu deleted file mode 100644 index 72adaf5cb13c64..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_bf16_real.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_complex.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_complex.cu deleted file mode 100644 index df126d7c8d5f9f..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_complex.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_real.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_real.cu deleted file mode 100644 index 3ff271b50eaff2..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp16_real.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_complex.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_complex.cu deleted file mode 100644 index 5554902342785b..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_complex.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_real.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_real.cu deleted file mode 100644 index a7ed642231da80..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_fp32_real.cu +++ /dev/null @@ -1,9 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_bwd_kernel.cuh" - -template void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_kernel.cuh b/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_kernel.cuh deleted file mode 100644 index 2ed101148a4b32..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/selective_scan_bwd_kernel.cuh +++ /dev/null @@ -1,531 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include -#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK -#include // For atomicAdd on complex - -#include -#include -#include -#include - -#include "selective_scan.h" -#include "selective_scan_common.h" -#include "reverse_scan.cuh" -#include "static_switch.h" - -template __device__ __forceinline__ scalar_t conj(scalar_t x); -template<> __device__ __forceinline__ float conj(float x) { return x; } -template<> __device__ __forceinline__ complex_t conj(complex_t x) { return std::conj(x); } - -template -struct Selective_Scan_bwd_kernel_traits { - static_assert(kNItems_ % 4 == 0); - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static constexpr int kNItems = kNItems_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); - static_assert(kNItems % kNElts == 0); - static constexpr int kNLoads = kNItems / kNElts; - static constexpr bool kIsComplex = std::is_same_v; - static constexpr bool kIsEvenLen = kIsEvenLen_; - static constexpr bool kIsVariableB = kIsVariableB_; - static constexpr bool kIsVariableC = kIsVariableC_; - static constexpr bool kDeltaSoftplus = kDeltaSoftplus_; - static constexpr bool kHasZ = kHasZ_; - // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads with float improves occupancy. - // For complex this would lead to massive register spilling, so we keep it at 2. - static constexpr int kMinBlocks = kNThreads == 128 && !kIsComplex ? 3 : 2; - using vec_t = typename BytesToType::Type; - using scan_t = std::conditional_t; - using BlockLoadT = cub::BlockLoad; - using BlockLoadVecT = cub::BlockLoad; - using BlockLoadWeightT = cub::BlockLoad; - using BlockLoadWeightVecT = cub::BlockLoad; - using BlockStoreT = cub::BlockStore; - using BlockStoreVecT = cub::BlockStore; - // using BlockScanT = cub::BlockScan; - using BlockScanT = cub::BlockScan; - // using BlockScanT = cub::BlockScan; - using BlockReverseScanT = BlockReverseScan; - using BlockReduceT = cub::BlockReduce; - using BlockReduceFloatT = cub::BlockReduce; - using BlockReduceComplexT = cub::BlockReduce; - using BlockExchangeT = cub::BlockExchange; - static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), - sizeof(typename BlockLoadVecT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), - sizeof(typename BlockStoreT::TempStorage), - sizeof(typename BlockStoreVecT::TempStorage)}); - static constexpr int kSmemExchangeSize = (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockExchangeT::TempStorage); - static constexpr int kSmemReduceSize = sizeof(typename BlockReduceT::TempStorage); - static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize + kSmemReduceSize + sizeof(typename BlockScanT::TempStorage) + sizeof(typename BlockReverseScanT::TempStorage); -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) -void selective_scan_bwd_kernel(SSMParamsBwd params) { - constexpr bool kIsComplex = Ktraits::kIsComplex; - constexpr bool kIsVariableB = Ktraits::kIsVariableB; - constexpr bool kIsVariableC = Ktraits::kIsVariableC; - constexpr bool kDeltaSoftplus = Ktraits::kDeltaSoftplus; - constexpr bool kHasZ = Ktraits::kHasZ; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNItems = Ktraits::kNItems; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - using scan_t = typename Ktraits::scan_t; - - // Shared memory. - extern __shared__ char smem_[]; - // cast to lvalue reference of expected type - // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); - // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); - // auto& smem_load = reinterpret_cast(smem_loadstorescan); - auto& smem_load = reinterpret_cast(smem_); - auto& smem_load_weight = reinterpret_cast(smem_); - auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); - auto& smem_store = reinterpret_cast(smem_); - auto& smem_exchange = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); - auto& smem_exchange1 = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize + sizeof(typename Ktraits::BlockExchangeT::TempStorage)); - auto& smem_reduce = *reinterpret_cast(reinterpret_cast(&smem_exchange) + Ktraits::kSmemExchangeSize); - auto& smem_reduce_float = *reinterpret_cast(&smem_reduce); - auto& smem_reduce_complex = *reinterpret_cast(&smem_reduce); - auto& smem_scan = *reinterpret_cast(reinterpret_cast(&smem_reduce) + Ktraits::kSmemReduceSize); - auto& smem_reverse_scan = *reinterpret_cast(reinterpret_cast(&smem_scan) + sizeof(typename Ktraits::BlockScanT::TempStorage)); - weight_t *smem_delta_a = reinterpret_cast(smem_ + Ktraits::kSmemSize); - scan_t *smem_running_postfix = reinterpret_cast(smem_delta_a + 2 * MAX_DSTATE + kNThreads); - weight_t *smem_da = reinterpret_cast(smem_running_postfix + MAX_DSTATE); - weight_t *smem_dbc = reinterpret_cast(smem_da + MAX_DSTATE); - - const int batch_id = blockIdx.x; - const int dim_id = blockIdx.y; - const int group_id = dim_id / (params.dim_ngroups_ratio); - input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride - + dim_id * params.u_d_stride; - input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride - + dim_id * params.delta_d_stride; - input_t *dout = reinterpret_cast(params.dout_ptr) + batch_id * params.dout_batch_stride - + dim_id * params.dout_d_stride; - weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * params.A_d_stride; - weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * params.B_d_stride; - input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; - weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * params.C_d_stride; - input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; - weight_t *dA = reinterpret_cast(params.dA_ptr) + dim_id * params.dA_d_stride; - weight_t *dB = reinterpret_cast(params.dB_ptr) - + (!kIsVariableB ? dim_id * params.dB_d_stride : batch_id * (!kIsComplex ? params.dB_batch_stride : params.dB_batch_stride / 2) + group_id * params.dB_group_stride); - weight_t *dC = reinterpret_cast(params.dC_ptr) - + (!kIsVariableC ? dim_id * params.dC_d_stride : batch_id * (!kIsComplex ? params.dC_batch_stride : params.dC_batch_stride / 2) + group_id * params.dC_group_stride); - float *dD = params.dD_ptr == nullptr ? nullptr : reinterpret_cast(params.dD_ptr) + dim_id; - float D_val = params.D_ptr == nullptr ? 0 : reinterpret_cast(params.D_ptr)[dim_id]; - float *ddelta_bias = params.ddelta_bias_ptr == nullptr ? nullptr : reinterpret_cast(params.ddelta_bias_ptr) + dim_id; - float delta_bias = params.delta_bias_ptr == nullptr ? 0 : reinterpret_cast(params.delta_bias_ptr)[dim_id]; - scan_t *x = params.x_ptr == nullptr - ? nullptr - : reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id) * (params.n_chunks) * params.dstate; - float dD_val = 0; - float ddelta_bias_val = 0; - - constexpr int kChunkSize = kNThreads * kNItems; - u += (params.n_chunks - 1) * kChunkSize; - delta += (params.n_chunks - 1) * kChunkSize; - dout += (params.n_chunks - 1) * kChunkSize; - Bvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); - Cvar += (params.n_chunks - 1) * kChunkSize * (!kIsComplex ? 1 : 2); - for (int chunk = params.n_chunks - 1; chunk >= 0; --chunk) { - input_t u_vals[kNItems]; - input_t delta_vals_load[kNItems]; - input_t dout_vals_load[kNItems]; - __syncthreads(); - load_input(u, u_vals, smem_load, params.seqlen - chunk * kChunkSize); - u -= kChunkSize; - __syncthreads(); - load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); - // Will reload delta at the same location if kDeltaSoftplus - if constexpr (!kDeltaSoftplus) { delta -= kChunkSize; } - __syncthreads(); - load_input(dout, dout_vals_load, smem_load, params.seqlen - chunk * kChunkSize); - dout -= kChunkSize; - - float dout_vals[kNItems], delta_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - dout_vals[i] = float(dout_vals_load[i]); - delta_vals[i] = float(delta_vals_load[i]) + delta_bias; - if constexpr (kDeltaSoftplus) { - delta_vals[i] = delta_vals[i] <= 20.f ? log1pf(expf(delta_vals[i])) : delta_vals[i]; - } - } - - if constexpr (kHasZ) { - input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride - + dim_id * params.z_d_stride + chunk * kChunkSize; - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + dim_id * params.out_d_stride + chunk * kChunkSize; - input_t *dz = reinterpret_cast(params.dz_ptr) + batch_id * params.dz_batch_stride - + dim_id * params.dz_d_stride + chunk * kChunkSize; - input_t z_vals[kNItems], out_vals[kNItems]; - __syncthreads(); - load_input(z, z_vals, smem_load, params.seqlen - chunk * kChunkSize); - __syncthreads(); - load_input(out, out_vals, smem_load, params.seqlen - chunk * kChunkSize); - float dz_vals[kNItems], z_silu_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float z_val = z_vals[i]; - float z_sigmoid_val = 1.0f / (1.0f + expf(-z_val)); - z_silu_vals[i] = z_val * z_sigmoid_val; - dz_vals[i] = dout_vals[i] * float(out_vals[i]) * z_sigmoid_val - * (1.0f + z_val * (1.0f - z_sigmoid_val)); - dout_vals[i] *= z_silu_vals[i]; - } - __syncthreads(); - store_output(dz, dz_vals, smem_store, params.seqlen - chunk * kChunkSize); - if (params.out_z_ptr != nullptr) { // Recompute and store out_z - float out_z_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { out_z_vals[i] = float(out_vals[i]) * z_silu_vals[i]; } - // if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0) { - // printf("out_val=%f, z_silu_val = %f, out_z_val = %f\n", float(out_vals[0]), z_silu_vals[0], out_z_vals[0]); - // } - input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride - + dim_id * params.out_z_d_stride + chunk * kChunkSize; - __syncthreads(); - store_output(out_z, out_z_vals, smem_store, params.seqlen - chunk * kChunkSize); - } - } - - float du_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { du_vals[i] = D_val * dout_vals[i]; } - #pragma unroll - for (int i = 0; i < kNItems; ++i) { dD_val += dout_vals[i] * float(u_vals[i]); } - - float ddelta_vals[kNItems] = {0}; - __syncthreads(); - for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { - const weight_t A_val = A[state_idx * params.A_dstate_stride]; - // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. - weight_t A_scaled; - constexpr float kLog2e = M_LOG2E; - if constexpr (!kIsComplex) { - A_scaled = A_val * kLog2e; - } else { - A_scaled = complex_t(A_val.real_ * kLog2e, A_val.imag_); - } - weight_t B_val, C_val; - weight_t B_vals[kNItems], C_vals[kNItems]; - if constexpr (!kIsVariableB) { - B_val = B[state_idx * params.B_dstate_stride]; - } else { - load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, - smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); - } - if constexpr (!kIsVariableC) { - C_val = C[state_idx * params.C_dstate_stride]; - } else { - auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; - load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); - } - // const weight_t A_val = smem_a[state_idx]; - scan_t thread_data[kNItems], thread_reverse_data[kNItems]; - if constexpr (!kIsComplex) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - const float delta_a_exp = exp2f(delta_vals[i] * A_scaled); - thread_data[i] = make_float2(delta_a_exp, !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); - if (i == 0) { - smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; - } else { - thread_reverse_data[i - 1].x = delta_a_exp; - } - thread_reverse_data[i].y = dout_vals[i] * - (!kIsVariableC - ? (!kIsVariableB ? B_val * C_val : C_val) - : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i])); - } - __syncthreads(); - thread_reverse_data[kNItems - 1].x = threadIdx.x == kNThreads - 1 - ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE]) - : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE]; - // Initialize running total - scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float2(1.f, 0.f); - SSMScanPrefixCallbackOp prefix_op(running_prefix); - Ktraits::BlockScanT(smem_scan).InclusiveScan( - thread_data, thread_data, SSMScanOp(), prefix_op - ); - scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float2(1.f, 0.f); - SSMScanPrefixCallbackOp postfix_op(running_postfix); - Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( - thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op - ); - if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } - weight_t dA_val = 0, dBC_val = 0; - weight_t dB_vals[kNItems], dC_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - const float dx = thread_reverse_data[i].y; - const float ddelta_u = !kIsVariableB ? dx : dx * B_vals[i]; - du_vals[i] += ddelta_u * delta_vals[i]; - const float a = thread_data[i].y - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i]); - ddelta_vals[i] += ddelta_u * float(u_vals[i]) + dx * A_val * a; - dA_val += dx * delta_vals[i] * a; - if constexpr (!kIsVariableB || !kIsVariableC) { - if constexpr (!kIsVariableB) { // dBC_val is dB_val - dBC_val += dout_vals[i] * (!kIsVariableC ? thread_data[i].y : thread_data[i].y * C_vals[i]); - } else { // dBC_val is dC_val - dBC_val += dout_vals[i] * thread_data[i].y; - } - } - if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } - if constexpr (kIsVariableC) { - dC_vals[i] = dout_vals[i] * (!kIsVariableB ? thread_data[i].y * B_val : thread_data[i].y); - } - } - // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower - if constexpr (kIsVariableB || kIsVariableC) { - if constexpr (kIsVariableB) { - Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals, dB_vals); - } - if constexpr (kIsVariableC) { - auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; - Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals, dC_vals); - } - const int seqlen_remaining = params.seqlen - chunk * kChunkSize - threadIdx.x; - weight_t *dB_cur = dB + state_idx * params.dB_dstate_stride + chunk * kChunkSize + threadIdx.x; - weight_t *dC_cur = dC + state_idx * params.dC_dstate_stride + chunk * kChunkSize + threadIdx.x; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - if (i * kNThreads < seqlen_remaining) { - if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals[i]); } - if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals[i]); } - } - } - } - if constexpr (!kIsVariableB || !kIsVariableC) { - float2 dA_dBC_val = make_float2(dA_val, dBC_val); - dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); - dA_val = dA_dBC_val.x; - if (threadIdx.x == 0) { - smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dA_dBC_val.y : dA_dBC_val.y + smem_dbc[state_idx]; - } - } else { - dA_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dA_val); - } - if (threadIdx.x == 0) { - smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; - } - } else { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - // Pytorch's implementation of complex exp (which calls thrust) is very slow - complex_t delta_a_exp = cexp2f(delta_vals[i] * A_scaled); - weight_t B_delta_u_val = !kIsVariableB ? delta_vals[i] * float(u_vals[i]) : B_vals[i] * delta_vals[i] * float(u_vals[i]); - thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); - if (i == 0) { - smem_delta_a[threadIdx.x == 0 ? state_idx + (chunk % 2) * MAX_DSTATE : threadIdx.x + 2 * MAX_DSTATE] = delta_a_exp; - } else { - thread_reverse_data[i - 1].x = delta_a_exp.real_; - thread_reverse_data[i - 1].y = -delta_a_exp.imag_; - } - complex_t dout_BC = 2 * dout_vals[i] - * conj(!kIsVariableC - ? (!kIsVariableB ? B_val * C_val : C_val) - : (!kIsVariableB ? B_val * C_vals[i] : C_vals[i])); - thread_reverse_data[i].z = dout_BC.real_; - thread_reverse_data[i].w = dout_BC.imag_; - } - __syncthreads(); - complex_t delta_a_exp = threadIdx.x == kNThreads - 1 - ? (chunk == params.n_chunks - 1 ? 1.f : smem_delta_a[state_idx + ((chunk + 1) % 2) * MAX_DSTATE]) - : smem_delta_a[threadIdx.x + 1 + 2 * MAX_DSTATE]; - thread_reverse_data[kNItems - 1].x = delta_a_exp.real_; - thread_reverse_data[kNItems - 1].y = -delta_a_exp.imag_; - // Initialize running total - scan_t running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? x[(chunk - 1) * params.dstate + state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); - SSMScanPrefixCallbackOp prefix_op(running_prefix); - Ktraits::BlockScanT(smem_scan).InclusiveScan( - thread_data, thread_data, SSMScanOp(), prefix_op - ); - scan_t running_postfix = chunk < params.n_chunks - 1 && threadIdx.x % 32 == 0 ? smem_running_postfix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); - SSMScanPrefixCallbackOp postfix_op(running_postfix); - Ktraits::BlockReverseScanT(smem_reverse_scan).InclusiveReverseScan( - thread_reverse_data, thread_reverse_data, SSMScanOp(), postfix_op - ); - if (threadIdx.x == 0) { smem_running_postfix[state_idx] = postfix_op.running_prefix; } - weight_t dA_val = 0, dBC_val = 0; - weight_t dB_vals[kNItems], dC_vals[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - complex_t x = complex_t(thread_data[i].z, thread_data[i].w); - complex_t dx = complex_t(thread_reverse_data[i].z, thread_reverse_data[i].w); - float ddelta_u = !kIsVariableB ? dx.real_ : (dx * conj(B_vals[i])).real_; - if constexpr (!kIsVariableB || !kIsVariableC) { - if constexpr (!kIsVariableB) { // dBC_val is dB_val - dBC_val += (2 * dout_vals[i]) * conj(!kIsVariableC ? x : x * C_vals[i]); - } else { // dBC_val is dC_val - dBC_val += (2 * dout_vals[i]) * conj(x); - } - } - const complex_t a_conj = conj(x - (!kIsVariableB ? delta_vals[i] * float(u_vals[i]) : delta_vals[i] * float(u_vals[i]) * B_vals[i])); - du_vals[i] += ddelta_u * delta_vals[i]; - ddelta_vals[i] += ddelta_u * float(u_vals[i]) + (dx * conj(A_val) * a_conj).real_; - dA_val += delta_vals[i] * dx * a_conj; - if constexpr (kIsVariableB) { dB_vals[i] = dx * delta_vals[i] * float(u_vals[i]); } - if constexpr (kIsVariableC) { - dC_vals[i] = (2 * dout_vals[i]) * conj(!kIsVariableB ? x * B_val : x); - } - } - // Block-exchange to make the atomicAdd's coalesced, otherwise they're much slower - if constexpr (kIsVariableB || kIsVariableC) { - float dB_vals_f[kNItems * 2], dC_vals_f[kNItems * 2]; - if constexpr (kIsVariableB) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - dB_vals_f[i * 2] = dB_vals[i].real_; - dB_vals_f[i * 2 + 1] = dB_vals[i].imag_; - } - Ktraits::BlockExchangeT(smem_exchange).BlockedToStriped(dB_vals_f, dB_vals_f); - } - if constexpr (kIsVariableC) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - dC_vals_f[i * 2] = dC_vals[i].real_; - dC_vals_f[i * 2 + 1] = dC_vals[i].imag_; - } - auto &smem_exchange_C = !kIsVariableB ? smem_exchange : smem_exchange1; - Ktraits::BlockExchangeT(smem_exchange_C).BlockedToStriped(dC_vals_f, dC_vals_f); - } - const int seqlen_remaining = (params.seqlen - chunk * kChunkSize) * 2 - threadIdx.x; - float *dB_cur = reinterpret_cast(dB) + state_idx * params.dB_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; - float *dC_cur = reinterpret_cast(dC) + state_idx * params.dC_dstate_stride + chunk * kChunkSize * 2 + threadIdx.x; - #pragma unroll - for (int i = 0; i < kNItems * 2; ++i) { - if (i * kNThreads < seqlen_remaining) { - if constexpr (kIsVariableB) { gpuAtomicAdd(dB_cur + i * kNThreads, dB_vals_f[i]); } - if constexpr (kIsVariableC) { gpuAtomicAdd(dC_cur + i * kNThreads, dC_vals_f[i]); } - } - } - } - if constexpr (!kIsVariableB || !kIsVariableC) { - float4 dA_dBC_val = make_float4(dA_val.real_, dA_val.imag_, dBC_val.real_, dBC_val.imag_); - dA_dBC_val = Ktraits::BlockReduceT(smem_reduce).Sum(dA_dBC_val); - dA_val = complex_t(dA_dBC_val.x, dA_dBC_val.y); - dBC_val = complex_t(dA_dBC_val.z, dA_dBC_val.w); - if (threadIdx.x == 0) { - smem_dbc[state_idx] = chunk == params.n_chunks - 1 ? dBC_val : dBC_val + smem_dbc[state_idx]; - } - } else { - dA_val = Ktraits::BlockReduceComplexT(smem_reduce_complex).Sum(dA_val); - } - if (threadIdx.x == 0) { - smem_da[state_idx] = chunk == params.n_chunks - 1 ? dA_val : dA_val + smem_da[state_idx]; - } - } - } - - if constexpr (kDeltaSoftplus) { - __syncthreads(); - input_t delta_vals_load[kNItems]; - load_input(delta, delta_vals_load, smem_load, params.seqlen - chunk * kChunkSize); - delta -= kChunkSize; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float delta_val = float(delta_vals_load[i]) + delta_bias; - float delta_val_neg_exp = expf(-delta_val); - ddelta_vals[i] = delta_val <= 20.f - ? ddelta_vals[i] / (1.f + delta_val_neg_exp) - : ddelta_vals[i]; - } - } - for (int i = 0; i < kNItems; ++i) { ddelta_bias_val += ddelta_vals[i]; } - - input_t *du = reinterpret_cast(params.du_ptr) + batch_id * params.du_batch_stride - + dim_id * params.du_d_stride + chunk * kChunkSize; - input_t *ddelta = reinterpret_cast(params.ddelta_ptr) + batch_id * params.ddelta_batch_stride - + dim_id * params.ddelta_d_stride + chunk * kChunkSize; - __syncthreads(); - store_output(du, du_vals, smem_store, params.seqlen - chunk * kChunkSize); - __syncthreads(); - store_output(ddelta, ddelta_vals, smem_store, params.seqlen - chunk * kChunkSize); - - Bvar -= kChunkSize * (!kIsComplex ? 1 : 2); - Cvar -= kChunkSize * (!kIsComplex ? 1 : 2); - } - if (params.dD_ptr != nullptr) { - dD_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dD_val); - if (threadIdx.x == 0) { gpuAtomicAdd(dD, dD_val); } - } - if (params.ddelta_bias_ptr != nullptr) { - __syncthreads(); - ddelta_bias_val = Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(ddelta_bias_val); - if (threadIdx.x == 0) { gpuAtomicAdd(ddelta_bias, ddelta_bias_val); } - } - for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { - gpuAtomicAdd(&(dA[state_idx * params.dA_dstate_stride]), smem_da[state_idx]); - weight_t dBC_val; - if (!kIsVariableB || !kIsVariableC) { dBC_val = smem_dbc[state_idx]; } - if constexpr (!kIsVariableB) { - gpuAtomicAdd(&(dB[state_idx * params.dB_dstate_stride]), - !kIsVariableC ? dBC_val * conj(C[state_idx * params.C_dstate_stride]) : dBC_val); - } - if constexpr (!kIsVariableC) { - gpuAtomicAdd(&(dC[state_idx * params.dC_dstate_stride]), - !kIsVariableB ? dBC_val * conj(B[state_idx * params.B_dstate_stride]) : dBC_val); - } - } -} - -template -void selective_scan_bwd_launch(SSMParamsBwd ¶ms, cudaStream_t stream) { - BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { - BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { - BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { - BOOL_SWITCH(params.delta_softplus, kDeltaSoftplus, [&] { - BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - using Ktraits = Selective_Scan_bwd_kernel_traits; - // using Ktraits = Selective_Scan_bwd_kernel_traits; - // TODO: check this - constexpr int kSmemSize = Ktraits::kSmemSize + MAX_DSTATE * sizeof(typename Ktraits::scan_t) + (kNThreads + 4 * MAX_DSTATE) * sizeof(typename Ktraits::weight_t); - // printf("smem_size = %d\n", kSmemSize); - dim3 grid(params.batch, params.dim); - auto kernel = &selective_scan_bwd_kernel; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); - }); - }); -} - -template -void selective_scan_bwd_cuda(SSMParamsBwd ¶ms, cudaStream_t stream) { - if (params.seqlen <= 128) { - selective_scan_bwd_launch<32, 4, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 256) { - selective_scan_bwd_launch<32, 8, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 512) { - selective_scan_bwd_launch<32, 16, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 1024) { - selective_scan_bwd_launch<64, 16, input_t, weight_t>(params, stream); - } else { - selective_scan_bwd_launch<128, 16, input_t, weight_t>(params, stream); - } -} \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_common.h b/src/transformers/kernels/mamba/selective_scan/selective_scan_common.h deleted file mode 100644 index 9140dcdf3b68ad..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/selective_scan_common.h +++ /dev/null @@ -1,221 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include -#include // For scalar_value_type - -#define MAX_DSTATE 256 - -using complex_t = c10::complex; - -inline __device__ float2 operator+(const float2 & a, const float2 & b){ - return {a.x + b.x, a.y + b.y}; -} - -inline __device__ float3 operator+(const float3 &a, const float3 &b) { - return {a.x + b.x, a.y + b.y, a.z + b.z}; -} - -inline __device__ float4 operator+(const float4 & a, const float4 & b){ - return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template struct BytesToType {}; - -template<> struct BytesToType<16> { - using Type = uint4; - static_assert(sizeof(Type) == 16); -}; - -template<> struct BytesToType<8> { - using Type = uint64_t; - static_assert(sizeof(Type) == 8); -}; - -template<> struct BytesToType<4> { - using Type = uint32_t; - static_assert(sizeof(Type) == 4); -}; - -template<> struct BytesToType<2> { - using Type = uint16_t; - static_assert(sizeof(Type) == 2); -}; - -template<> struct BytesToType<1> { - using Type = uint8_t; - static_assert(sizeof(Type) == 1); -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Converter{ - static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { - #pragma unroll - for (int i = 0; i < N; ++i) { dst[i] = src[i]; } - } -}; - -template -struct Converter{ - static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { - static_assert(N % 2 == 0); - auto &src2 = reinterpret_cast(src); - auto &dst2 = reinterpret_cast(dst); - #pragma unroll - for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } - } -}; - -#if __CUDA_ARCH__ >= 800 -template -struct Converter{ - static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { - static_assert(N % 2 == 0); - auto &src2 = reinterpret_cast(src); - auto &dst2 = reinterpret_cast(dst); - #pragma unroll - for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } - } -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// From https://stackoverflow.com/questions/9860711/cucomplex-h-and-exp -// and https://forums.developer.nvidia.com/t/complex-number-exponential-function/24696 -__device__ __forceinline__ complex_t cexp2f(complex_t z) { - float t = exp2f(z.real_); - float c, s; - sincosf(z.imag_, &s, &c); - return complex_t(c * t, s * t); -} - -__device__ __forceinline__ complex_t cexpf(complex_t z) { - float t = expf(z.real_); - float c, s; - sincosf(z.imag_, &s, &c); - return complex_t(c * t, s * t); -} - -template struct SSMScanOp; - -template<> -struct SSMScanOp { - __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { - return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); - } -}; - -template<> -struct SSMScanOp { - __device__ __forceinline__ float4 operator()(const float4 &ab0, const float4 &ab1) const { - complex_t a0 = complex_t(ab0.x, ab0.y); - complex_t b0 = complex_t(ab0.z, ab0.w); - complex_t a1 = complex_t(ab1.x, ab1.y); - complex_t b1 = complex_t(ab1.z, ab1.w); - complex_t out_a = a1 * a0; - complex_t out_b = a1 * b0 + b1; - return make_float4(out_a.real_, out_a.imag_, out_b.real_, out_b.imag_); - } -}; - -// A stateful callback functor that maintains a running prefix to be applied -// during consecutive scan operations. -template struct SSMScanPrefixCallbackOp { - using scan_t = std::conditional_t, float2, float4>; - scan_t running_prefix; - // Constructor - __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} - // Callback operator to be entered by the first warp of threads in the block. - // Thread-0 is responsible for returning a value for seeding the block-wide scan. - __device__ scan_t operator()(scan_t block_aggregate) { - scan_t old_prefix = running_prefix; - running_prefix = SSMScanOp()(running_prefix, block_aggregate); - return old_prefix; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ void load_input(typename Ktraits::input_t *u, - typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadT::TempStorage &smem_load, - int seqlen) { - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_vec = reinterpret_cast(smem_load); - using vec_t = typename Ktraits::vec_t; - Ktraits::BlockLoadVecT(smem_load_vec).Load( - reinterpret_cast(u), - reinterpret_cast(u_vals) - ); - } else { - Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); - } -} - -template -inline __device__ void load_weight(typename Ktraits::input_t *Bvar, - typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, - int seqlen) { - constexpr int kNItems = Ktraits::kNItems; - if constexpr (!Ktraits::kIsComplex) { - typename Ktraits::input_t B_vals_load[kNItems]; - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); - using vec_t = typename Ktraits::vec_t; - Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( - reinterpret_cast(Bvar), - reinterpret_cast(B_vals_load) - ); - } else { - Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); - } - // #pragma unroll - // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } - Converter::to_float(B_vals_load, B_vals); - } else { - typename Ktraits::input_t B_vals_load[kNItems * 2]; - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); - using vec_t = typename Ktraits::vec_t; - Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( - reinterpret_cast(Bvar), - reinterpret_cast(B_vals_load) - ); - } else { - Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); - } - #pragma unroll - for (int i = 0; i < kNItems; ++i) { B_vals[i] = complex_t(B_vals_load[i * 2], B_vals_load[i * 2 + 1]); } - } -} - -template -inline __device__ void store_output(typename Ktraits::input_t *out, - const float (&out_vals)[Ktraits::kNItems], - typename Ktraits::BlockStoreT::TempStorage &smem_store, - int seqlen) { - typename Ktraits::input_t write_vals[Ktraits::kNItems]; - #pragma unroll - for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_store_vec = reinterpret_cast(smem_store); - using vec_t = typename Ktraits::vec_t; - Ktraits::BlockStoreVecT(smem_store_vec).Store( - reinterpret_cast(out), - reinterpret_cast(write_vals) - ); - } else { - Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); - } -} diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_bf16.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_bf16.cu deleted file mode 100644 index 2b8615b1d522c1..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_bf16.cu +++ /dev/null @@ -1,10 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp16.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp16.cu deleted file mode 100644 index 015e2a0eff633d..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp16.cu +++ /dev/null @@ -1,10 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp32.cu b/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp32.cu deleted file mode 100644 index c142fe0208ea78..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_fp32.cu +++ /dev/null @@ -1,10 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -// Split into multiple files to compile in paralell - -#include "selective_scan_fwd_kernel.cuh" - -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); -template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); \ No newline at end of file diff --git a/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_kernel.cuh b/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_kernel.cuh deleted file mode 100644 index 440a209108bfe1..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/selective_scan_fwd_kernel.cuh +++ /dev/null @@ -1,345 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2023, Tri Dao. - ******************************************************************************/ - -#pragma once - -#include -#include -#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK - -#include -#include -#include - -#include "selective_scan.h" -#include "selective_scan_common.h" -#include "static_switch.h" - -template -struct Selective_Scan_fwd_kernel_traits { - static_assert(kNItems_ % 4 == 0); - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. - static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; - static constexpr int kNItems = kNItems_; - static constexpr int kNRows = kNRows_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : std::min(8, kNItems); - static_assert(kNItems % kNElts == 0); - static constexpr int kNLoads = kNItems / kNElts; - static constexpr bool kIsComplex = std::is_same_v; - static constexpr bool kIsEvenLen = kIsEvenLen_; - static constexpr bool kIsVariableB = kIsVariableB_; - static constexpr bool kIsVariableC = kIsVariableC_; - static constexpr bool kHasZ = kHasZ_; - - static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; - - using vec_t = typename BytesToType::Type; - using scan_t = std::conditional_t; - using BlockLoadT = cub::BlockLoad; - using BlockLoadVecT = cub::BlockLoad; - using BlockLoadWeightT = cub::BlockLoad; - using BlockLoadWeightVecT = cub::BlockLoad; - using BlockStoreT = cub::BlockStore; - using BlockStoreVecT = cub::BlockStore; - // using BlockScanT = cub::BlockScan; - // using BlockScanT = cub::BlockScan; - using BlockScanT = cub::BlockScan; - static constexpr int kSmemIOSize = std::max({sizeof(typename BlockLoadT::TempStorage), - sizeof(typename BlockLoadVecT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), - (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), - sizeof(typename BlockStoreT::TempStorage), - sizeof(typename BlockStoreVecT::TempStorage)}); - static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) -void selective_scan_fwd_kernel(SSMParamsBase params) { - constexpr bool kIsComplex = Ktraits::kIsComplex; - constexpr bool kIsVariableB = Ktraits::kIsVariableB; - constexpr bool kIsVariableC = Ktraits::kIsVariableC; - constexpr bool kHasZ = Ktraits::kHasZ; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNItems = Ktraits::kNItems; - constexpr int kNRows = Ktraits::kNRows; - constexpr bool kDirectIO = Ktraits::kDirectIO; - using input_t = typename Ktraits::input_t; - using weight_t = typename Ktraits::weight_t; - using scan_t = typename Ktraits::scan_t; - - // Shared memory. - extern __shared__ char smem_[]; - // cast to lvalue reference of expected type - // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); - // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); - // auto& smem_load = reinterpret_cast(smem_loadstorescan); - auto& smem_load = reinterpret_cast(smem_); - auto& smem_load_weight = reinterpret_cast(smem_); - auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); - auto& smem_store = reinterpret_cast(smem_); - auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); - // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); - // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); - scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); - - const int batch_id = blockIdx.x; - const int dim_id = blockIdx.y; - const int group_id = dim_id / (params.dim_ngroups_ratio); - input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride - + dim_id * kNRows * params.u_d_stride; - input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride - + dim_id * kNRows * params.delta_d_stride; - weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; - weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; - input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; - weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; - input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; - scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; - - float D_val[kNRows] = {0}; - if (params.D_ptr != nullptr) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; - } - } - float delta_bias[kNRows] = {0}; - if (params.delta_bias_ptr != nullptr) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; - } - } - - // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { - // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; - // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; - // } - - constexpr int kChunkSize = kNThreads * kNItems; - for (int chunk = 0; chunk < params.n_chunks; ++chunk) { - input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; - __syncthreads(); - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if constexpr (!kDirectIO) { - if (r > 0) { __syncthreads(); } - } - load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); - if constexpr (!kDirectIO) { __syncthreads(); } - load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); - } - u += kChunkSize; - delta += kChunkSize; - - float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float u_val = float(u_vals[r][i]); - delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; - if (params.delta_softplus) { - delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; - } - delta_u_vals[r][i] = delta_vals[r][i] * u_val; - out_vals[r][i] = D_val[r] * u_val; - } - } - - __syncthreads(); - for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { - weight_t A_val[kNRows]; - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; - // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. - constexpr float kLog2e = M_LOG2E; - if constexpr (!kIsComplex) { - A_val[r] *= kLog2e; - } else { - A_val[r].real_ *= kLog2e; - } - } - // This variable holds B * C if both B and C are constant across seqlen. If only B varies - // across seqlen, this holds C. If only C varies across seqlen, this holds B. - // If both B and C vary, this is unused. - weight_t BC_val[kNRows]; - weight_t B_vals[kNItems], C_vals[kNItems]; - if constexpr (kIsVariableB) { - load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, - smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); - if constexpr (!kIsVariableC) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; - } - } - } - if constexpr (kIsVariableC) { - auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; - load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2)); - if constexpr (!kIsVariableB) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; - } - } - } - if constexpr (!kIsVariableB && !kIsVariableC) { - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; - } - } - - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if (r > 0) { __syncthreads(); } // Scan could be using the same smem - scan_t thread_data[kNItems]; - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - if constexpr (!kIsComplex) { - thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), - !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); - if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct - if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { - thread_data[i] = make_float2(1.f, 0.f); - } - } - } else { - // Pytorch's implementation of complex exp (which calls thrust) is very slow - complex_t delta_a_exp = cexp2f(delta_vals[r][i] * A_val[r]); - weight_t B_delta_u_val = !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]; - thread_data[i] = make_float4(delta_a_exp.real_, delta_a_exp.imag_, B_delta_u_val.real_, B_delta_u_val.imag_); - if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct - if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { - thread_data[i] = make_float4(1.f, 0.f, 0.f, 0.f); - } - } - } - } - // Initialize running total - scan_t running_prefix; - if constexpr (!kIsComplex) { - // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read - running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f); - // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); - } else { - running_prefix = chunk > 0 && threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float4(1.f, 0.f, 0.f, 0.f); - // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float4(1.f, 0.f, 0.f, 0.f); - } - SSMScanPrefixCallbackOp prefix_op(running_prefix); - Ktraits::BlockScanT(smem_scan).InclusiveScan( - thread_data, thread_data, SSMScanOp(), prefix_op - ); - // There's a syncthreads in the scan op, so we don't need to sync here. - // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. - if (threadIdx.x == 0) { - smem_running_prefix[state_idx] = prefix_op.running_prefix; - x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; - } - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - const weight_t C_val = !kIsVariableC - ? BC_val[r] - : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]); - if constexpr (!kIsComplex) { - out_vals[r][i] += thread_data[i].y * C_val; - } else { - out_vals[r][i] += (complex_t(thread_data[i].z, thread_data[i].w) * C_val).real_ * 2; - } - } - } - } - - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; - __syncthreads(); - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - if constexpr (!kDirectIO) { - if (r > 0) { __syncthreads(); } - } - store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); - } - - if constexpr (kHasZ) { - input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride - + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; - input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride - + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; - #pragma unroll - for (int r = 0; r < kNRows; ++r) { - input_t z_vals[kNItems]; - __syncthreads(); - load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); - #pragma unroll - for (int i = 0; i < kNItems; ++i) { - float z_val = z_vals[i]; - out_vals[r][i] *= z_val / (1 + expf(-z_val)); - } - __syncthreads(); - store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); - } - } - - Bvar += kChunkSize * (!kIsComplex ? 1 : 2); - Cvar += kChunkSize * (!kIsComplex ? 1 : 2); - } -} - -template -void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { - // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block - // processing 1 row. - constexpr int kNRows = 1; - BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { - BOOL_SWITCH(params.is_variable_B, kIsVariableB, [&] { - BOOL_SWITCH(params.is_variable_C, kIsVariableC, [&] { - BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits; - // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); - // printf("smem_size = %d\n", kSmemSize); - dim3 grid(params.batch, params.dim / kNRows); - auto kernel = &selective_scan_fwd_kernel; - if (kSmemSize >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); - }); - }); - }); -} - -template -void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { - if (params.seqlen <= 128) { - selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 256) { - selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 512) { - selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); - } else if (params.seqlen <= 1024) { - selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); - } else { - selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); - } -} diff --git a/src/transformers/kernels/mamba/selective_scan/static_switch.h b/src/transformers/kernels/mamba/selective_scan/static_switch.h deleted file mode 100644 index 7920ac045d0a2a..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/static_switch.h +++ /dev/null @@ -1,25 +0,0 @@ -// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h -// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h - -#pragma once - -/// @param COND - a boolean expression to switch by -/// @param CONST_NAME - a name given for the constexpr bool variable. -/// @param ... - code to execute for true and false -/// -/// Usage: -/// ``` -/// BOOL_SWITCH(flag, BoolConst, [&] { -/// some_function(...); -/// }); -/// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() diff --git a/src/transformers/kernels/mamba/selective_scan/uninitialized_copy.cuh b/src/transformers/kernels/mamba/selective_scan/uninitialized_copy.cuh deleted file mode 100644 index 630622dddcc904..00000000000000 --- a/src/transformers/kernels/mamba/selective_scan/uninitialized_copy.cuh +++ /dev/null @@ -1,69 +0,0 @@ -/****************************************************************************** - * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the - * names of its contributors may be used to endorse or promote products - * derived from this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY - * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - ******************************************************************************/ - -#pragma once - -#include - -#include - - -namespace detail -{ - -#if defined(_NVHPC_CUDA) -template -__host__ __device__ void uninitialized_copy(T *ptr, U &&val) -{ - // NVBug 3384810 - new (ptr) T(::cuda::std::forward(val)); -} -#else -template ::value, - int - >::type = 0> -__host__ __device__ void uninitialized_copy(T *ptr, U &&val) -{ - *ptr = ::cuda::std::forward(val); -} - -template ::value, - int - >::type = 0> -__host__ __device__ void uninitialized_copy(T *ptr, U &&val) -{ - new (ptr) T(::cuda::std::forward(val)); -} -#endif - -} // namespace detail diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index b5534865b9c7d5..d25f994d7fcc3e 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -66,6 +66,8 @@ class MambaConfig(PretrainedConfig): `rescale_every` layer. If set to 0 or a negative number, no rescale is done. tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether or not to tie the word embeddings with the input token embeddings. + residual_in_fp32 (`bool`, *optional*, defaults to `False`): + Whether or not residuals should be in `float32`. use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last state. @@ -105,6 +107,7 @@ def __init__( use_conv_bias=True, hidden_act="silu", initializer_range=0.1, + residual_in_fp32=False, **kwargs, ): self.vocab_size = vocab_size @@ -124,6 +127,7 @@ def __init__( self.use_conv_bias = use_conv_bias self.hidden_act = hidden_act self.initializer_range = initializer_range + self.residual_in_fp32 = residual_in_fp32 super().__init__( tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index a41b523f42e08f..69a982a9af3a41 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -35,18 +35,21 @@ ) from .configuration_mamba import MambaConfig +logger = logging.get_logger(__name__) -if False and is_causal_conv_1d_available(): - from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +if is_mamba_ssm_available(): + from mamba_ssm.ops.triton.selective_state_update import selective_state_update + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn else: - causal_conv1d_update, causal_conv1d_fn = None, None + logger.warning_once(" `mamba_ssm` is not installed in your environnement. Make sure to install it following `src/transformers/kernels/mamba/Makefile`") + selective_state_update, selective_scan_fn = None, None -if False and is_kernel_compiled(): +if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update else: - selective_scan_update, selective_scan_fn = None, None + logger.warning_once(" `causal_conv1d` is not installed in your environnement. Make sure to install it: `src/transformers/kernels/mamba/Makefile`") + causal_conv1d_update, causal_conv1d_fn = None, None -logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m" _CONFIG_FOR_DOC = "MambaConfig" @@ -56,43 +59,6 @@ # See all Mamba models at https://huggingface.co/models?filter=mamba ] - -mamba_cuda_kernel = None - - -# Copied from transformers.models.mamba.modeling_mamba.load_mamba_cuda_kernel with mamba->MAMBA,mamba->mamba -def load_mamba_cuda_kernel(context_length): - from torch.utils.cpp_extension import load as load_kernel - - global mamba_cuda_kernel - - kernel_folder = Path(__file__).resolve().parent.parent.parent / "kernels" / "mamba" - cuda_kernel_files = [kernel_folder / f for f in ["mamba_op.cpp", "mamba_cuda.cu", "mamba_cuda_bf16.cu"]] - - # Only load the kernel if it's not been loaded yet or if we changed the context length - if mamba_cuda_kernel is not None and mamba_cuda_kernel.max_seq_length == context_length: - return - - logger.info(f"Loading CUDA kernel for MAMBA at context length of {context_length}.") - - flags = [ - "-res-usage", - "--maxrregcount 60", - "--use_fast_math", - "-O3", - "-Xptxas -O3", - "--extra-device-vectorization", - f"-DTmax={context_length}", - ] - mamba_cuda_kernel = load_kernel( - name=f"mamba_{context_length}", - sources=cuda_kernel_files, - verbose=(logging.get_verbosity() == logging.DEBUG), - extra_cuda_cflags=flags, - ) - mamba_cuda_kernel.max_seq_length = context_length - - class MambaMixer(nn.Module): def __init__(self, config, layer_idx): super().__init__() @@ -173,31 +139,13 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): # 3.c perform the recurrence y ← SSM(A, B, C)(x) ssm_state = inference_params.ssm_states[self.layer_idx] if inference_params is not None and inference_params.seq_offset > 0: - y, _ = selective_scan_update( - ssm_state, - hidden_states, - discrete_time_step, - self.negA, - B, - C, - self.D, - z=gate, - dt_bias=self.dt_proj.bias, - dt_softplus=True, - ) + y = selective_state_update( + ssm_state, hidden_states, discrete_time_step, self.negA, B, C, self.D, z=gate, dt_bias=self.dt_proj.bias, dt_softplus=True + ) #fmt: skip else: - y, last_state = selective_scan_fn( - hidden_states, - discrete_time_step, - self.negA, - B, - C, - self.D.float(), - z=gate, - delta_bias=self.dt_proj.bias.float(), - delta_softplus=True, - return_last_state=True, - ) + y, last_state = selective_scan_fn( hidden_states, discrete_time_step, self.negA, B, C, self.D.float(), z=gate, delta_bias=self.dt_proj.bias.float(), delta_softplus=True, return_last_state=True,) + if last_state is not None: + ssm_state = last_state # 4. Final linear projection attn_outputs = self.out_proj(y.transpose(1, 2)) @@ -312,15 +260,21 @@ def __init__(self, config, layer_idx): super().__init__() self.config = config self.layer_idx = layer_idx - # self.residual_in_fp32 = config.residual_in_fp32 + self.residual_in_fp32 = config.residual_in_fp32 self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.mixer = MambaMixerSlow(config, layer_idx=layer_idx) + + if any(selective_scan_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update) is None: + MIXER_CLS = MambaMixerSlow + else: # CUDA is available and the kernels are also available + MIXER_CLS = MambaMixer + + self.mixer = MIXER_CLS(config, layer_idx=layer_idx) def forward(self, hidden_states, inference_params=None): residual = hidden_states hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) - # if self.residual_in_fp32: - # residual = residual.to(torch.float32) + if self.residual_in_fp32: + residual = residual.to(torch.float32) hidden_states, conv_states, ssm_state = self.mixer(hidden_states, inference_params=inference_params) hidden_states = residual.to(torch.float32) + hidden_states From 1a1031011d349576831b4ab9decee08b7d315d16 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 08:27:31 +0100 Subject: [PATCH 028/116] let use use the original packages --- .../models/mamba/modeling_mamba.py | 1 + src/transformers/utils/import_utils.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 69a982a9af3a41..b8699d7313c057 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -34,6 +34,7 @@ logging, ) from .configuration_mamba import MambaConfig +from ...utils.import_utils import is_mamba_ssm_available, is_causal_conv1d_available logger = logging.get_logger(__name__) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 57b4e840414be0..430b10395248c5 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -305,6 +305,25 @@ def is_torch_cuda_available(): else: return False +def is_mamba_ssm_available(): + if is_torch_available(): + import torch + + if not torch.cuda.is_available(): + return False + else: + return _is_package_available("selective-scan-cuda") + return False + +def is_causal_conv1d_available(): + if is_torch_available(): + import torch + + if not torch.cuda.is_available(): + return False + return _is_package_available("causal_conv1d") + return False + def is_torch_mps_available(): if is_torch_available(): From 89fb490b551a29d56720ed7096a8995fec5fc596 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 08:29:08 +0100 Subject: [PATCH 029/116] styling --- .../models/auto/tokenization_auto.py | 2 +- .../models/mamba/modeling_mamba.py | 39 +++++++++++++------ src/transformers/utils/import_utils.py | 6 ++- 3 files changed, 32 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py index a8086f4e8e19fa..82959969e23dc9 100644 --- a/src/transformers/models/auto/tokenization_auto.py +++ b/src/transformers/models/auto/tokenization_auto.py @@ -226,6 +226,7 @@ ("luke", ("LukeTokenizer", None)), ("lxmert", ("LxmertTokenizer", "LxmertTokenizerFast" if is_tokenizers_available() else None)), ("m2m_100", ("M2M100Tokenizer" if is_sentencepiece_available() else None, None)), + ("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ("marian", ("MarianTokenizer" if is_sentencepiece_available() else None, None)), ( "mbart", @@ -368,7 +369,6 @@ ("roc_bert", ("RoCBertTokenizer", None)), ("roformer", ("RoFormerTokenizer", "RoFormerTokenizerFast" if is_tokenizers_available() else None)), ("rwkv", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), - ("mamba", (None, "GPTNeoXTokenizerFast" if is_tokenizers_available() else None)), ( "seamless_m4t", ( diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index b8699d7313c057..232aa4c6b3ff4f 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -16,7 +16,6 @@ import math from dataclasses import dataclass -from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -33,22 +32,27 @@ add_start_docstrings_to_model_forward, logging, ) +from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available from .configuration_mamba import MambaConfig -from ...utils.import_utils import is_mamba_ssm_available, is_causal_conv1d_available + logger = logging.get_logger(__name__) if is_mamba_ssm_available(): - from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.selective_scan_interface import selective_scan_fn + from mamba_ssm.ops.triton.selective_state_update import selective_state_update else: - logger.warning_once(" `mamba_ssm` is not installed in your environnement. Make sure to install it following `src/transformers/kernels/mamba/Makefile`") + logger.warning_once( + " `mamba_ssm` is not installed in your environnement. Make sure to install it following `src/transformers/kernels/mamba/Makefile`" + ) selective_state_update, selective_scan_fn = None, None if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update else: - logger.warning_once(" `causal_conv1d` is not installed in your environnement. Make sure to install it: `src/transformers/kernels/mamba/Makefile`") + logger.warning_once( + " `causal_conv1d` is not installed in your environnement. Make sure to install it: `src/transformers/kernels/mamba/Makefile`" + ) causal_conv1d_update, causal_conv1d_fn = None, None @@ -60,6 +64,7 @@ # See all Mamba models at https://huggingface.co/models?filter=mamba ] + class MambaMixer(nn.Module): def __init__(self, config, layer_idx): super().__init__() @@ -105,7 +110,6 @@ def __init__(self, config, layer_idx): self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=config.use_bias) def forward(self, hidden_states: torch.Tensor, inference_params=None): - batch_size, seq_len, _ = hidden_states.shape # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) hidden_states, gate = projected_states.chunk(2, dim=1) @@ -141,10 +145,21 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): ssm_state = inference_params.ssm_states[self.layer_idx] if inference_params is not None and inference_params.seq_offset > 0: y = selective_state_update( - ssm_state, hidden_states, discrete_time_step, self.negA, B, C, self.D, z=gate, dt_bias=self.dt_proj.bias, dt_softplus=True - ) #fmt: skip + ssm_state, hidden_states, discrete_time_step, self.negA, B, C, self.D, z=gate, dt_bias=self.dt_proj.bias, dt_softplus=True + ) # fmt: skip else: - y, last_state = selective_scan_fn( hidden_states, discrete_time_step, self.negA, B, C, self.D.float(), z=gate, delta_bias=self.dt_proj.bias.float(), delta_softplus=True, return_last_state=True,) + y, last_state = selective_scan_fn( + hidden_states, + discrete_time_step, + self.negA, + B, + C, + self.D.float(), + z=gate, + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=True, + ) if last_state is not None: ssm_state = last_state @@ -264,11 +279,11 @@ def __init__(self, config, layer_idx): self.residual_in_fp32 = config.residual_in_fp32 self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - if any(selective_scan_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update) is None: + if any(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update) is None: MIXER_CLS = MambaMixerSlow - else: # CUDA is available and the kernels are also available + else: # CUDA is available and the kernels are also available MIXER_CLS = MambaMixer - + self.mixer = MIXER_CLS(config, layer_idx=layer_idx) def forward(self, hidden_states, inference_params=None): diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 430b10395248c5..569d0dba9516e0 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -305,6 +305,7 @@ def is_torch_cuda_available(): else: return False + def is_mamba_ssm_available(): if is_torch_available(): import torch @@ -312,9 +313,10 @@ def is_mamba_ssm_available(): if not torch.cuda.is_available(): return False else: - return _is_package_available("selective-scan-cuda") + return _is_package_available("selective-scan-cuda") return False + def is_causal_conv1d_available(): if is_torch_available(): import torch @@ -323,7 +325,7 @@ def is_causal_conv1d_available(): return False return _is_package_available("causal_conv1d") return False - + def is_torch_mps_available(): if is_torch_available(): From 6cfe216ccd5bb0e037512ee2a150493a476b17a9 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 08:30:29 +0100 Subject: [PATCH 030/116] nits --- src/transformers/models/mamba/modeling_mamba.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 232aa4c6b3ff4f..6502f58597104c 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -404,19 +404,18 @@ class MambaCausalLMOutput(ModelOutput): one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. + Last known states. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None inference_params: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None + states: Optional[Tuple[torch.FloatTensor]] = None MAMBA_START_DOCSTRING = r""" From 1ecbd22317985fdb90a5027b6eacf6f832993918 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 08:31:30 +0100 Subject: [PATCH 031/116] fix some copieds --- docs/source/en/index.md | 1 + docs/source/en/tasks/language_modeling.md | 2 +- .../models/mamba/modeling_mamba.py | 1 - src/transformers/utils/dummy_pt_objects.py | 24 +++++++++++++++++++ 4 files changed, 26 insertions(+), 2 deletions(-) diff --git a/docs/source/en/index.md b/docs/source/en/index.md index 81dc97e97134c8..683b51ecb445d1 100644 --- a/docs/source/en/index.md +++ b/docs/source/en/index.md @@ -179,6 +179,7 @@ Flax), PyTorch, and/or TensorFlow. | [M-CTC-T](model_doc/mctct) | ✅ | ❌ | ❌ | | [M2M100](model_doc/m2m_100) | ✅ | ❌ | ❌ | | [MADLAD-400](model_doc/madlad-400) | ✅ | ✅ | ✅ | +| [Mamba](model_doc/mamba) | ✅ | ❌ | ❌ | | [Marian](model_doc/marian) | ✅ | ✅ | ✅ | | [MarkupLM](model_doc/markuplm) | ✅ | ❌ | ❌ | | [Mask2Former](model_doc/mask2former) | ✅ | ❌ | ❌ | diff --git a/docs/source/en/tasks/language_modeling.md b/docs/source/en/tasks/language_modeling.md index 4022867a027af7..2372b7bd45f81a 100644 --- a/docs/source/en/tasks/language_modeling.md +++ b/docs/source/en/tasks/language_modeling.md @@ -37,7 +37,7 @@ You can finetune other architectures for causal language modeling following the Choose one of the following architectures: -[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod) +[BART](../model_doc/bart), [BERT](../model_doc/bert), [Bert Generation](../model_doc/bert-generation), [BigBird](../model_doc/big_bird), [BigBird-Pegasus](../model_doc/bigbird_pegasus), [BioGpt](../model_doc/biogpt), [Blenderbot](../model_doc/blenderbot), [BlenderbotSmall](../model_doc/blenderbot-small), [BLOOM](../model_doc/bloom), [CamemBERT](../model_doc/camembert), [CodeLlama](../model_doc/code_llama), [CodeGen](../model_doc/codegen), [CPM-Ant](../model_doc/cpmant), [CTRL](../model_doc/ctrl), [Data2VecText](../model_doc/data2vec-text), [ELECTRA](../model_doc/electra), [ERNIE](../model_doc/ernie), [Falcon](../model_doc/falcon), [Fuyu](../model_doc/fuyu), [GIT](../model_doc/git), [GPT-Sw3](../model_doc/gpt-sw3), [OpenAI GPT-2](../model_doc/gpt2), [GPTBigCode](../model_doc/gpt_bigcode), [GPT Neo](../model_doc/gpt_neo), [GPT NeoX](../model_doc/gpt_neox), [GPT NeoX Japanese](../model_doc/gpt_neox_japanese), [GPT-J](../model_doc/gptj), [LLaMA](../model_doc/llama), [Mamba](../model_doc/mamba), [Marian](../model_doc/marian), [mBART](../model_doc/mbart), [MEGA](../model_doc/mega), [Megatron-BERT](../model_doc/megatron-bert), [Mistral](../model_doc/mistral), [Mixtral](../model_doc/mixtral), [MPT](../model_doc/mpt), [MusicGen](../model_doc/musicgen), [MVP](../model_doc/mvp), [OpenLlama](../model_doc/open-llama), [OpenAI GPT](../model_doc/openai-gpt), [OPT](../model_doc/opt), [Pegasus](../model_doc/pegasus), [Persimmon](../model_doc/persimmon), [Phi](../model_doc/phi), [PLBart](../model_doc/plbart), [ProphetNet](../model_doc/prophetnet), [QDQBert](../model_doc/qdqbert), [Qwen2](../model_doc/qwen2), [Reformer](../model_doc/reformer), [RemBERT](../model_doc/rembert), [RoBERTa](../model_doc/roberta), [RoBERTa-PreLayerNorm](../model_doc/roberta-prelayernorm), [RoCBert](../model_doc/roc_bert), [RoFormer](../model_doc/roformer), [RWKV](../model_doc/rwkv), [Speech2Text2](../model_doc/speech_to_text_2), [StableLm](../model_doc/stablelm), [Transformer-XL](../model_doc/transfo-xl), [TrOCR](../model_doc/trocr), [Whisper](../model_doc/whisper), [XGLM](../model_doc/xglm), [XLM](../model_doc/xlm), [XLM-ProphetNet](../model_doc/xlm-prophetnet), [XLM-RoBERTa](../model_doc/xlm-roberta), [XLM-RoBERTa-XL](../model_doc/xlm-roberta-xl), [XLNet](../model_doc/xlnet), [X-MOD](../model_doc/xmod) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 6502f58597104c..082769e02a9acd 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -579,7 +579,6 @@ def forward( """, MAMBA_START_DOCSTRING, ) -# Copied from transformers.models.mamba.modeling_mamba.mambaForCausalLM with mamba->MAMBA,mamba->Mamba,mamba->mamba class MambaForCausalLM(MambaPreTrainedModel): _tied_weights_keys = ["head.weight"] diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 3b8316ba547294..9761ba2947e37c 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -4991,6 +4991,30 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class MambaForCausalLM(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MambaModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + +class MambaPreTrainedModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class MarianForCausalLM(metaclass=DummyObject): _backends = ["torch"] From b937122e7a3c903b70367fc841d44e481d10923e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 08:35:54 +0100 Subject: [PATCH 032/116] update doc --- README.md | 1 + docs/source/en/model_doc/mamba.md | 79 ++----------------- .../models/mamba/configuration_mamba.py | 28 +++---- 3 files changed, 20 insertions(+), 88 deletions(-) diff --git a/README.md b/README.md index b7077ce61032ba..5768f60337fce7 100644 --- a/README.md +++ b/README.md @@ -413,6 +413,7 @@ Current number of checkpoints: ![](https://img.shields.io/endpoint?url=https://h 1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert. 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat. +1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (from Albert Gu and Tri Dao) released with the paper [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (from Microsoft Research Asia) released with the paper [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) by Junlong Li, Yiheng Xu, Lei Cui, Furu Wei. 1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. diff --git a/docs/source/en/model_doc/mamba.md b/docs/source/en/model_doc/mamba.md index 0a56eca918cc6e..70218937fe8186 100644 --- a/docs/source/en/model_doc/mamba.md +++ b/docs/source/en/model_doc/mamba.md @@ -22,19 +22,20 @@ rendered properly in your Markdown viewer. ## Overview -The Mamba model was proposed in []() by . - +The Mamba model was proposed in [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao. + +This model is a new paradigm architecture based on `state-space-models`. You can read more about the intuition behind these [here](blog). The abstract from the paper is the following: -** +*Foundation models, now powering most of the exciting applications in deep learning, are almost universally based on the Transformer architecture and its core attention module. Many subquadratic-time architectures such as linear attention, gated convolution and recurrent models, and structured state space models (SSMs) have been developed to address Transformers' computational inefficiency on long sequences, but they have not performed as well as attention on important modalities such as language. We identify that a key weakness of such models is their inability to perform content-based reasoning, and make several improvements. First, simply letting the SSM parameters be functions of the input addresses their weakness with discrete modalities, allowing the model to selectively propagate or forget information along the sequence length dimension depending on the current token. Second, even though this change prevents the use of efficient convolutions, we design a hardware-aware parallel algorithm in recurrent mode. We integrate these selective SSMs into a simplified end-to-end neural network architecture without attention or even MLP blocks (Mamba). Mamba enjoys fast inference (5× higher throughput than Transformers) and linear scaling in sequence length, and its performance improves on real data up to million-length sequences. As a general sequence model backbone, Mamba achieves state-of-the-art performance across several modalities such as language, audio, and genomics. On language modeling, our Mamba-3B model outperforms Transformers of the same size and matches Transformers twice its size, both in pretraining and downstream evaluation.* Tips: - +- in order to run the fast version of the model you should install `causal_conv1d` and `mamba` -This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/). -The original code can be found [here](). +This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/ArthurZ). +The original code can be found [here](https://github.com/state-spaces/mamba). ## MambaConfig @@ -50,69 +51,3 @@ The original code can be found [here](). [[autodoc]] MambaForCausalLM - forward - -## Mamba attention and the recurrent formulas - -In a traditional auto-regressive Transformer, attention is written as - -$$O = \hbox{softmax}(QK^{T} / \sqrt{d}) V$$ - -with \\(Q\\), \\(K\\) and \\(V\\) are matrices of shape `seq_len x hidden_size` named query, key and value (they are actually bigger matrices with a batch dimension and an attention head dimension but we're only interested in the last two, which is where the matrix product is taken, so for the sake of simplicity we only consider those two). The product \\(QK^{T}\\) then has shape `seq_len x seq_len` and we can take the maxtrix product with \\(V\\) to get the output \\(O\\) of the same shape as the others. - -Replacing the softmax by its value gives: - -$$O_{i} = \frac{\sum_{j=1}^{i} e^{Q_{i} K_{j}^{T} / \sqrt{d}} V_{j}}{\sum_{j=1}^{i} e^{Q_{i} K_{j}^{T} / \sqrt{d}}}$$ - -Note that the entries in \\(QK^{T}\\) corresponding to \\(j > i\\) are masked (the sum stops at j) because the attention is not allowed to look at future tokens (only past ones). - -In comparison, the MAMBA attention is given by - -$$O_{i} = \sigma(R_{i}) \frac{\sum_{j=1}^{i} e^{W_{i-j} + K_{j}} V_{j}}{\sum_{j=1}^{i} e^{W_{i-j} + K_{j}}}$$ - -where \\(R\\) is a new matrix called receptance by the author, \\(K\\) and \\(V\\) are still the key and value (\\(\sigma\\) here is the sigmoid function). \\(W\\) is a new vector that represents the position of the token and is given by - -$$W_{0} = u \hbox{ and } W_{k} = (k-1)w \hbox{ for } k \geq 1$$ - -with \\(u\\) and \\(w\\) learnable parameters called in the code `time_first` and `time_decay` respectively. The numerator and denominator can both be expressed recursively. Naming them \\(N_{i}\\) and \\(D_{i}\\) we have: - -$$N_{i} = e^{u + K_{i}} V_{i} + \hat{N}_{i} \hbox{ where } \hat{N}_{i} = e^{K_{i-1}} V_{i-1} + e^{w + K_{i-2}} V_{i-2} \cdots + e^{(i-2)w + K_{1}} V_{1}$$ - -so \\(\hat{N}_{i}\\) (called `numerator_state` in the code) satistfies - -$$\hat{N}_{0} = 0 \hbox{ and } \hat{N}_{j+1} = e^{K_{j}} V_{j} + e^{w} \hat{N}_{j}$$ - -and - -$$D_{i} = e^{u + K_{i}} + \hat{D}_{i} \hbox{ where } \hat{D}_{i} = e^{K_{i-1}} + e^{w + K_{i-2}} \cdots + e^{(i-2)w + K_{1}}$$ - -so \\(\hat{D}_{i}\\) (called `denominator_state` in the code) satistfies - -$$\hat{D}_{0} = 0 \hbox{ and } \hat{D}_{j+1} = e^{K_{j}} + e^{w} \hat{D}_{j}$$ - -The actual recurrent formula used are a tiny bit more complex, as for numerical stability we don't want to compute exponentials of big numbers. Usually the softmax is not computed as is, but the exponential of the maximum term is divided of the numerator and denominator: - -$$\frac{e^{x_{i}}}{\sum_{j=1}^{n} e^{x_{j}}} = \frac{e^{x_{i} - M}}{\sum_{j=1}^{n} e^{x_{j} - M}}$$ - -with \\(M\\) the maximum of all \\(x_{j}\\). So here on top of saving the numerator state (\\(\hat{N}\\)) and the denominator state (\\(\hat{D}\\)) we also keep track of the maximum of all terms encountered in the exponentials. So we actually use - -$$\tilde{N}_{i} = e^{-M_{i}} \hat{N}_{i} \hbox{ and } \tilde{D}_{i} = e^{-M_{i}} \hat{D}_{i}$$ - -defined by the following recurrent formulas: - -$$\tilde{N}_{0} = 0 \hbox{ and } \tilde{N}_{j+1} = e^{K_{j} - q} V_{j} + e^{w + M_{j} - q} \tilde{N}_{j} \hbox{ where } q = \max(K_{j}, w + M_{j})$$ - -and - -$$\tilde{D}_{0} = 0 \hbox{ and } \tilde{D}_{j+1} = e^{K_{j} - q} + e^{w + M_{j} - q} \tilde{D}_{j} \hbox{ where } q = \max(K_{j}, w + M_{j})$$ - -and \\(M_{j+1} = q\\). With those, we can then compute - -$$N_{i} = e^{u + K_{i} - q} V_{i} + e^{M_{i}} \tilde{N}_{i} \hbox{ where } q = \max(u + K_{i}, M_{i})$$ - -and - -$$D_{i} = e^{u + K_{i} - q} + e^{M_{i}} \tilde{D}_{i} \hbox{ where } q = \max(u + K_{i}, M_{i})$$ - -which finally gives us - -$$O_{i} = \sigma(R_{i}) \frac{N_{i}}{D_{i}}$$ \ No newline at end of file diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index d25f994d7fcc3e..4eb00c214ca066 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -39,37 +39,33 @@ class MambaConfig(PretrainedConfig): Args: - vocab_size (`int`, *optional*, defaults to 50277): + vocab_size (`int`, *optional*, defaults to 50280): Vocabulary size of the MAMBA model. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling [`MambaModel`]. - context_length (`int`, *optional*, defaults to 1024): - The maximum sequence length that this model can be be used with in a single forward (using it in RNN mode - lets use any sequence length). - hidden_size (`int`, *optional*, defaults to 4096): + hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the embeddings and hidden states. + state_size (``, *optional*, defaults to 16): num_hidden_layers (`int`, *optional*, defaults to 32): Number of hidden layers in the model. - attention_hidden_size (`int`, *optional*): - Dimensionality of the attention hidden states. Will default to `hidden_size` if unset. - intermediate_size (`int`, *optional*): - Dimensionality of the inner feed-forward layers. Will default to 4 times `hidden_size` if unset. layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): The epsilon to use in the layer normalization layers. - bos_token_id (`int`, *optional*, defaults to 0): + pad_token_id (``, *optional*, defaults to 0): + bos_token_id (`int`, *optional*, defaults to 1): The id of the beginning of sentence token in the vocabulary. Defaults to 0 as MAMBA uses the same tokenizer as GPTNeoX. - eos_token_id (`int`, *optional*, defaults to 0): + eos_token_id (`int`, *optional*, defaults to 2): The id of the end of sentence token in the vocabulary. Defaults to 0 as MAMBA uses the same tokenizer as GPTNeoX. - rescale_every (`int`, *optional*, defaults to 6): - At inference, the hidden states (and weights of the correponding output layers) are divided by 2 every - `rescale_every` layer. If set to 0 or a negative number, no rescale is done. + expand (``, *optional*, defaults to 2): + dt_rank (``, *optional*, defaults to `"auto"`): tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether or not to tie the word embeddings with the input token embeddings. + use_bias (``, *optional*, defaults to `False`): + use_conv_bias (``, *optional*, defaults to `True`): + hidden_act (``, *optional*, defaults to `"silu"`): + initializer_range (``, *optional*, defaults to 0.1): residual_in_fp32 (`bool`, *optional*, defaults to `False`): Whether or not residuals should be in `float32`. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last state. Example: From 9752dd03c540b8ad5b033b9602a5106f21cd0610 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 08:37:09 +0100 Subject: [PATCH 033/116] fix-copies --- README_es.md | 1 + README_fr.md | 1 + README_hd.md | 1 + README_ja.md | 1 + README_ko.md | 1 + README_zh-hans.md | 1 + README_zh-hant.md | 1 + 7 files changed, 7 insertions(+) diff --git a/README_es.md b/README_es.md index 9dfbf8931abada..ac2588269cf092 100644 --- a/README_es.md +++ b/README_es.md @@ -386,6 +386,7 @@ Número actual de puntos de control: ![](https://img.shields.io/endpoint?url=htt 1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert. 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat. +1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (from Albert Gu and Tri Dao) released with the paper [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (from Microsoft Research Asia) released with the paper [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) by Junlong Li, Yiheng Xu, Lei Cui, Furu Wei. 1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. diff --git a/README_fr.md b/README_fr.md index 75ebdd315f651d..54437a2af6874b 100644 --- a/README_fr.md +++ b/README_fr.md @@ -407,6 +407,7 @@ Nombre actuel de points de contrôle : ![](https://img.shields.io/endpoint?url=h 1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (de Facebook) a été publié dans l'article [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) de Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve et Ronan Collobert. 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (de Facebook) a été publié dans l'article [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) de Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (de Google) a été publié dans l'article [MADLAD-400 : Un ensemble de données multilingue et de niveau document](https://arxiv.org/abs/2309.04662) de Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat. +1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (de Albert Gu and Tri Dao) publié dans l'article [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) parAlbert Gu and Tri Dao. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Des modèles de traduction automatique formés avec les données [OPUS](http://opus.nlpl.eu/) par Jörg Tiedemann. Le [cadre Marian](https://marian-nmt.github.io/) est en cours de développement par l'équipe Microsoft Translator. 1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (de Microsoft Research Asia) a été publié dans l'article [MarkupLM : Pré-entraînement de texte et de langage de balisage pour la compréhension visuellement riche de documents](https://arxiv.org/abs/2110.08518) de Junlong Li, Yiheng Xu, Lei Cui, Furu Wei. 1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (de FAIR et UIUC) a été publié dans l'article [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) de Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. diff --git a/README_hd.md b/README_hd.md index 6402c3ee5eb7fc..f78900aa958984 100644 --- a/README_hd.md +++ b/README_hd.md @@ -360,6 +360,7 @@ conda install conda-forge::transformers 1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert. 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (फेसबुक से) साथ देने वाला पेपर [बियॉन्ड इंग्लिश-सेंट्रिक मल्टीलिंगुअल मशीन ट्रांसलेशन](https://arxiv.org/एब्स/2010.11125) एंजेला फैन, श्रुति भोसले, होल्गर श्वेन्क, झी मा, अहमद अल-किश्की, सिद्धार्थ गोयल, मनदीप बैनेस, ओनूर सेलेबी, गुइल्लाम वेन्जेक, विश्रव चौधरी, नमन गोयल, टॉम बर्च, विटाली लिपचिंस्की, सर्गेई एडुनोव, एडौर्ड द्वारा ग्रेव, माइकल औली, आर्मंड जौलिन द्वारा पोस्ट किया गया। 1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat. +1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (Albert Gu and Tri Dao से) Albert Gu and Tri Dao. द्वाराअनुसंधान पत्र [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) के साथ जारी किया गया 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Jörg द्वारा [OPUS](http://opus.nlpl.eu/) डेटा से प्रशिक्षित मशीनी अनुवाद मॉडल पोस्ट किया गया टाइडेमैन द्वारा। [मैरियन फ्रेमवर्क](https://marian-nmt.github.io/) माइक्रोसॉफ्ट ट्रांसलेटर टीम द्वारा विकसित। 1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (माइक्रोसॉफ्ट रिसर्च एशिया से) साथ में पेपर [मार्कअपएलएम: विजुअली-रिच डॉक्यूमेंट अंडरस्टैंडिंग के लिए टेक्स्ट और मार्कअप लैंग्वेज का प्री-ट्रेनिंग](https://arxiv.org/abs/2110.08518) जुनलॉन्ग ली, यिहेंग जू, लेई कुई, फुरु द्वारा वी द्वारा पोस्ट किया गया। 1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (FAIR and UIUC से) Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. द्वाराअनुसंधान पत्र [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) के साथ जारी किया गया diff --git a/README_ja.md b/README_ja.md index bd8a058b7b1b96..94085ecba1f0c0 100644 --- a/README_ja.md +++ b/README_ja.md @@ -420,6 +420,7 @@ Flax、PyTorch、TensorFlowをcondaでインストールする方法は、それ 1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (Facebook から) Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert から公開された研究論文: [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (Facebook から) Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin から公開された研究論文: [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) 1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat. +1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (Albert Gu and Tri Dao から) Albert Gu and Tri Dao. から公開された研究論文 [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Jörg Tiedemann から. [OPUS](http://opus.nlpl.eu/) を使いながら学習された "Machine translation" (マシントランスレーション) モデル. [Marian Framework](https://marian-nmt.github.io/) はMicrosoft Translator Team が現在開発中です. 1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (Microsoft Research Asia から) Junlong Li, Yiheng Xu, Lei Cui, Furu Wei から公開された研究論文: [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) 1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (FAIR and UIUC から) Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. から公開された研究論文 [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) diff --git a/README_ko.md b/README_ko.md index 533ab4685bce09..b66044312fc919 100644 --- a/README_ko.md +++ b/README_ko.md @@ -335,6 +335,7 @@ Flax, PyTorch, TensorFlow 설치 페이지에서 이들을 conda로 설치하는 1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (Facebook 에서) Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert 의 [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) 논문과 함께 발표했습니다. 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (Facebook 에서) Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin 의 [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) 논문과 함께 발표했습니다. 1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat. +1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (Albert Gu and Tri Dao 에서 제공)은 Albert Gu and Tri Dao.의 [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752)논문과 함께 발표했습니다. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (Microsoft Research Asia 에서) Junlong Li, Yiheng Xu, Lei Cui, Furu Wei 의 [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) 논문과 함께 발표했습니다. 1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (FAIR and UIUC 에서 제공)은 Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar.의 [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527)논문과 함께 발표했습니다. diff --git a/README_zh-hans.md b/README_zh-hans.md index f2b9b38273bfba..4f3a93e0a8948d 100644 --- a/README_zh-hans.md +++ b/README_zh-hans.md @@ -359,6 +359,7 @@ conda install conda-forge::transformers 1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (来自 Facebook) 伴随论文 [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) 由 Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert 发布。 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (来自 Facebook) 伴随论文 [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) 由 Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin 发布。 1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat. +1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (来自 Albert Gu and Tri Dao) 伴随论文 [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) 由 Albert Gu and Tri Dao 发布。 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** 用 [OPUS](http://opus.nlpl.eu/) 数据训练的机器翻译模型由 Jörg Tiedemann 发布。[Marian Framework](https://marian-nmt.github.io/) 由微软翻译团队开发。 1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (来自 Microsoft Research Asia) 伴随论文 [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) 由 Junlong Li, Yiheng Xu, Lei Cui, Furu Wei 发布。 1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (来自 FAIR and UIUC) 伴随论文 [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) 由 Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar 发布。 diff --git a/README_zh-hant.md b/README_zh-hant.md index 1d5155529aa0a3..8e75d915ead6d5 100644 --- a/README_zh-hant.md +++ b/README_zh-hant.md @@ -371,6 +371,7 @@ conda install conda-forge::transformers 1. **[M-CTC-T](https://huggingface.co/docs/transformers/model_doc/mctct)** (from Facebook) released with the paper [Pseudo-Labeling For Massively Multilingual Speech Recognition](https://arxiv.org/abs/2111.00161) by Loren Lugosch, Tatiana Likhomanenko, Gabriel Synnaeve, and Ronan Collobert. 1. **[M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)** (from Facebook) released with the paper [Beyond English-Centric Multilingual Machine Translation](https://arxiv.org/abs/2010.11125) by Angela Fan, Shruti Bhosale, Holger Schwenk, Zhiyi Ma, Ahmed El-Kishky, Siddharth Goyal, Mandeep Baines, Onur Celebi, Guillaume Wenzek, Vishrav Chaudhary, Naman Goyal, Tom Birch, Vitaliy Liptchinsky, Sergey Edunov, Edouard Grave, Michael Auli, Armand Joulin. 1. **[MADLAD-400](https://huggingface.co/docs/transformers/model_doc/madlad-400)** (from Google) released with the paper [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662) by Sneha Kudugunta, Isaac Caswell, Biao Zhang, Xavier Garcia, Christopher A. Choquette-Choo, Katherine Lee, Derrick Xin, Aditya Kusupati, Romi Stella, Ankur Bapna, Orhan Firat. +1. **[Mamba](https://huggingface.co/docs/transformers/main/model_doc/mamba)** (from Albert Gu and Tri Dao) released with the paper [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752) by Albert Gu and Tri Dao. 1. **[MarianMT](https://huggingface.co/docs/transformers/model_doc/marian)** Machine translation models trained using [OPUS](http://opus.nlpl.eu/) data by Jörg Tiedemann. The [Marian Framework](https://marian-nmt.github.io/) is being developed by the Microsoft Translator Team. 1. **[MarkupLM](https://huggingface.co/docs/transformers/model_doc/markuplm)** (from Microsoft Research Asia) released with the paper [MarkupLM: Pre-training of Text and Markup Language for Visually-rich Document Understanding](https://arxiv.org/abs/2110.08518) by Junlong Li, Yiheng Xu, Lei Cui, Furu Wei. 1. **[Mask2Former](https://huggingface.co/docs/transformers/model_doc/mask2former)** (from FAIR and UIUC) released with the paper [Masked-attention Mask Transformer for Universal Image Segmentation](https://arxiv.org/abs/2112.01527) by Bowen Cheng, Ishan Misra, Alexander G. Schwing, Alexander Kirillov, Rohit Girdhar. From a7881a3ca3715dddcfc7d5c5b89755b78b032cb2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 08:41:02 +0100 Subject: [PATCH 034/116] styling done --- src/transformers/models/mamba/modeling_mamba.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 082769e02a9acd..d62b46b2c0a61e 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -586,7 +586,6 @@ def __init__(self, config): super().__init__(config) self.backbone = MambaModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Initialize weights and apply final processing self.post_init() From f445b0daa63db4a1c97d12dc72359b48ae1abfcd Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 08:49:28 +0100 Subject: [PATCH 035/116] nits --- src/transformers/models/mamba/modeling_mamba.py | 13 +++++++------ src/transformers/utils/import_utils.py | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index d62b46b2c0a61e..e746999957c47c 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -115,7 +115,7 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation - if inference_params is not None and inference_params.seq_offset > 0: + if inference_params is not None and inference_params.seqlen_offset > 0: conv_state = inference_params.conv_states[self.layer_idx] hidden_states = causal_conv1d_update( hidden_states, @@ -127,7 +127,7 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): else: conv_state = nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0)) hidden_states = causal_conv1d_fn( - hidden_states=hidden_states, + x=hidden_states, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), bias=self.conv1d.bias, activation=self.activation, @@ -141,17 +141,18 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) + A = -torch.exp(self.A_log.float()) # 3.c perform the recurrence y ← SSM(A, B, C)(x) ssm_state = inference_params.ssm_states[self.layer_idx] - if inference_params is not None and inference_params.seq_offset > 0: + if inference_params is not None and inference_params.seqlen_offset > 0: y = selective_state_update( - ssm_state, hidden_states, discrete_time_step, self.negA, B, C, self.D, z=gate, dt_bias=self.dt_proj.bias, dt_softplus=True + ssm_state, hidden_states, discrete_time_step, A, B, C, self.D, z=gate, dt_bias=self.dt_proj.bias, dt_softplus=True ) # fmt: skip else: y, last_state = selective_scan_fn( hidden_states, discrete_time_step, - self.negA, + A, B, C, self.D.float(), @@ -279,7 +280,7 @@ def __init__(self, config, layer_idx): self.residual_in_fp32 = config.residual_in_fp32 self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - if any(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update) is None: + if any((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update)) is None: MIXER_CLS = MambaMixerSlow else: # CUDA is available and the kernels are also available MIXER_CLS = MambaMixer diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 569d0dba9516e0..ad406d49d4a77c 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -313,7 +313,7 @@ def is_mamba_ssm_available(): if not torch.cuda.is_available(): return False else: - return _is_package_available("selective-scan-cuda") + return _is_package_available("mamba-ssm") return False From 64ec8dd62fc4a30de58e7b7933b37f85fef44520 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 08:50:11 +0100 Subject: [PATCH 036/116] fix import check --- src/transformers/utils/import_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index ad406d49d4a77c..c16e6748b49cc1 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -313,7 +313,7 @@ def is_mamba_ssm_available(): if not torch.cuda.is_available(): return False else: - return _is_package_available("mamba-ssm") + return _is_package_available("mamba_ssm") return False From e6e3ba8c602fb9d600f0772eaf5dc05dacae5004 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 09:07:07 +0100 Subject: [PATCH 037/116] run but wrong cuda ress --- .../models/mamba/modeling_mamba.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index e746999957c47c..7f324c4741f937 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -118,12 +118,12 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): if inference_params is not None and inference_params.seqlen_offset > 0: conv_state = inference_params.conv_states[self.layer_idx] hidden_states = causal_conv1d_update( - hidden_states, + hidden_states.squeeze(-1), conv_state, self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), self.conv1d.bias, self.activation, - ) + ).unsqueeze(-1) else: conv_state = nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0)) hidden_states = causal_conv1d_fn( @@ -146,15 +146,15 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): ssm_state = inference_params.ssm_states[self.layer_idx] if inference_params is not None and inference_params.seqlen_offset > 0: y = selective_state_update( - ssm_state, hidden_states, discrete_time_step, A, B, C, self.D, z=gate, dt_bias=self.dt_proj.bias, dt_softplus=True - ) # fmt: skip + ssm_state, hidden_states.squeeze(-1), discrete_time_step.squeeze(-1), A, B.squeeze(1), C.squeeze(1), self.D, z=gate.squeeze(-1), dt_bias=self.dt_proj.bias, dt_softplus=True + ).unsqueeze(-1) # fmt: skip else: y, last_state = selective_scan_fn( hidden_states, discrete_time_step, A, - B, - C, + B.transpose(1,2), + C.transpose(1,2), self.D.float(), z=gate, delta_bias=self.dt_proj.bias.float(), @@ -170,7 +170,7 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): class MambaCache: - def __init__(self, config, batch_size, conv_dtype=torch.float32, ssm_dtype=torch.float32, device=None): + def __init__(self, config, batch_size, conv_dtype=torch.float16, ssm_dtype=torch.float16, device=None): self.seqlen_offset = 0 d_model = config.hidden_size @@ -252,6 +252,9 @@ def forward(self, hidden_states, inference_params=None): y = y * self.act(gate) # (B D) # 4. Final linear projection attn_outputs = self.out_proj(y.transpose(1, 2)) + + inference_params.ssm_states[self.layer_idx].copy_(ssm_state) + return attn_outputs, None, ssm_state @@ -384,7 +387,7 @@ class MambaOutput(ModelOutput): last_hidden_state: torch.FloatTensor = None inference_params: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None + states: Optional[Tuple[torch.FloatTensor]] = None @dataclass @@ -542,7 +545,7 @@ def forward( ) else: hidden_states, conv_state, ssm_state = layer(hidden_states, inference_params=inference_params) - inference_params.ssm_states[idx].copy_(ssm_state) + # inference_params.ssm_states[idx].copy_(ssm_state) # TODO maybe for torch.compile + graph do things here # inference_params.conv_states[idx].copy_(conv_state) @@ -569,7 +572,7 @@ def forward( last_hidden_state=hidden_states, inference_params=inference_params, hidden_states=all_hidden_states, - attentions=all_last_states, + states=all_last_states, ) @@ -662,7 +665,7 @@ def forward( ) hidden_states = mamba_outputs[0] - logits = self.lm_head(hidden_states) + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)) loss = None if labels is not None: @@ -684,5 +687,5 @@ def forward( logits=logits, inference_params=mamba_outputs.inference_params, hidden_states=mamba_outputs.hidden_states, - attentions=mamba_outputs.attentions, + states=mamba_outputs.states, ) From ed4eb4c88568134b67faef5ca132c57153356599 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 09:18:49 +0100 Subject: [PATCH 038/116] mamba CUDA works :) --- src/transformers/models/mamba/modeling_mamba.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 7f324c4741f937..8e354c61409d90 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -137,9 +137,8 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): # 3.a. input varying initialization of time_step, B and C x_dbl = self.x_proj(hidden_states.transpose(1, 2)) time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1) - discrete_time_step = self.dt_proj(time_step) + discrete_time_step = self.dt_proj.weight @ time_step.transpose(1,2) # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) - discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) A = -torch.exp(self.A_log.float()) # 3.c perform the recurrence y ← SSM(A, B, C)(x) @@ -162,7 +161,7 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): return_last_state=True, ) if last_state is not None: - ssm_state = last_state + inference_params.ssm_states[self.layer_idx].copy_(ssm_state) # 4. Final linear projection attn_outputs = self.out_proj(y.transpose(1, 2)) From 4c8fc48c0292e8755bdd3d5d16ee17cd82958fef Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 09:21:09 +0100 Subject: [PATCH 039/116] fix the fast path --- src/transformers/models/mamba/modeling_mamba.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 8e354c61409d90..28ab0771f668e8 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -137,7 +137,7 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): # 3.a. input varying initialization of time_step, B and C x_dbl = self.x_proj(hidden_states.transpose(1, 2)) time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1) - discrete_time_step = self.dt_proj.weight @ time_step.transpose(1,2) + discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) A = -torch.exp(self.A_log.float()) @@ -152,8 +152,8 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): hidden_states, discrete_time_step, A, - B.transpose(1,2), - C.transpose(1,2), + B.transpose(1, 2), + C.transpose(1, 2), self.D.float(), z=gate, delta_bias=self.dt_proj.bias.float(), From 69e103fadb7fa5ed82e2b18c8d902272ef9642cf Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 09:25:58 +0100 Subject: [PATCH 040/116] config naming nits --- src/transformers/models/mamba/configuration_mamba.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 4eb00c214ca066..941fcc06b8933c 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -124,7 +124,7 @@ def __init__( self.hidden_act = hidden_act self.initializer_range = initializer_range self.residual_in_fp32 = residual_in_fp32 + self.tie_word_embeddings = tie_word_embeddings + self.dt_rank = dt_rank - super().__init__( - tie_word_embeddings=tie_word_embeddings, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs - ) + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) From ba21ff24dca85bf20904cd3538019f609281230c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 09:29:09 +0100 Subject: [PATCH 041/116] conversion script is not required at this stage --- .../mamba/convert_mamba_checkpoint_to_hf.py | 201 ------------------ 1 file changed, 201 deletions(-) delete mode 100644 src/transformers/models/mamba/convert_mamba_checkpoint_to_hf.py diff --git a/src/transformers/models/mamba/convert_mamba_checkpoint_to_hf.py b/src/transformers/models/mamba/convert_mamba_checkpoint_to_hf.py deleted file mode 100644 index 7f45e24ab77f7f..00000000000000 --- a/src/transformers/models/mamba/convert_mamba_checkpoint_to_hf.py +++ /dev/null @@ -1,201 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Convert a MAMBA checkpoint from BlinkDL to the Hugging Face format.""" - - -import argparse -import gc -import json -import os -import re - -import torch -from huggingface_hub import hf_hub_download - -from transformers import AutoModelForCausalLM, AutoTokenizer, MambaConfig, PreTrainedTokenizerFast -from transformers.modeling_utils import WEIGHTS_INDEX_NAME, shard_checkpoint - - -NUM_HIDDEN_LAYERS_MAPPING = { - "169M": 12, - "430M": 24, - "1B5": 24, - "3B": 32, - "7B": 32, - "14B": 40, -} - -HIDEN_SIZE_MAPPING = { - "169M": 768, - "430M": 1024, - "1B5": 2048, - "3B": 2560, - "7B": 4096, - "14B": 5120, -} - - -def convert_state_dict(state_dict): - state_dict_keys = list(state_dict.keys()) - for name in state_dict_keys: - weight = state_dict.pop(name) - # emb -> embedding - if name.startswith("emb."): - name = name.replace("emb.", "embeddings.") - # ln_0 -> pre_ln (only present at block 0) - if name.startswith("blocks.0.ln0"): - name = name.replace("blocks.0.ln0", "blocks.0.pre_ln") - # att -> attention - name = re.sub(r"blocks\.(\d+)\.att", r"blocks.\1.attention", name) - # ffn -> feed_forward - name = re.sub(r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward", name) - # time_mix_k -> time_mix_key and reshape - if name.endswith(".time_mix_k"): - name = name.replace(".time_mix_k", ".time_mix_key") - # time_mix_v -> time_mix_value and reshape - if name.endswith(".time_mix_v"): - name = name.replace(".time_mix_v", ".time_mix_value") - # time_mix_r -> time_mix_key and reshape - if name.endswith(".time_mix_r"): - name = name.replace(".time_mix_r", ".time_mix_receptance") - - if name != "head.weight": - name = "mamba." + name - - state_dict[name] = weight - return state_dict - - -def convert_rmkv_checkpoint_to_hf_format( - repo_id, checkpoint_file, output_dir, size=None, tokenizer_file=None, push_to_hub=False, model_name=None -): - # 1. If possible, build the tokenizer. - if tokenizer_file is None: - print("No `--tokenizer_file` provided, we will use the default tokenizer.") - vocab_size = 50277 - tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") - else: - tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file) - vocab_size = len(tokenizer) - tokenizer.save_pretrained(output_dir) - - # 2. Build the config - possible_sizes = list(NUM_HIDDEN_LAYERS_MAPPING.keys()) - if size is None: - # Try to infer size from the checkpoint name - for candidate in possible_sizes: - if candidate in checkpoint_file: - size = candidate - break - if size is None: - raise ValueError("Could not infer the size, please provide it with the `--size` argument.") - if size not in possible_sizes: - raise ValueError(f"`size` should be one of {possible_sizes}, got {size}.") - - config = MambaConfig( - vocab_size=vocab_size, - num_hidden_layers=NUM_HIDDEN_LAYERS_MAPPING[size], - hidden_size=HIDEN_SIZE_MAPPING[size], - ) - config.save_pretrained(output_dir) - - # 3. Download model file then convert state_dict - model_file = hf_hub_download(repo_id, checkpoint_file) - state_dict = torch.load(model_file, map_location="cpu") - state_dict = convert_state_dict(state_dict) - - # 4. Split in shards and save - shards, index = shard_checkpoint(state_dict) - for shard_file, shard in shards.items(): - torch.save(shard, os.path.join(output_dir, shard_file)) - - if index is not None: - save_index_file = os.path.join(output_dir, WEIGHTS_INDEX_NAME) - # Save the index as well - with open(save_index_file, "w", encoding="utf-8") as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" - f.write(content) - - # 5. Clean up shards (for some reason the file PyTorch saves take the same space as the whole state_dict - print( - "Cleaning up shards. This may error with an OOM error, it this is the case don't worry you still have converted the model." - ) - shard_files = list(shards.keys()) - - del state_dict - del shards - gc.collect() - - for shard_file in shard_files: - state_dict = torch.load(os.path.join(output_dir, shard_file)) - torch.save({k: v.cpu().clone() for k, v in state_dict.items()}, os.path.join(output_dir, shard_file)) - - del state_dict - gc.collect() - - if push_to_hub: - if model_name is None: - raise ValueError("Please provide a `model_name` to push the model to the Hub.") - model = AutoModelForCausalLM.from_pretrained(output_dir) - model.push_to_hub(model_name, max_shard_size="2GB") - tokenizer.push_to_hub(model_name) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # Required parameters - parser.add_argument( - "--repo_id", default=None, type=str, required=True, help="Repo ID from which to pull the checkpoint." - ) - parser.add_argument( - "--checkpoint_file", default=None, type=str, required=True, help="Name of the checkpoint file in the repo." - ) - parser.add_argument( - "--output_dir", default=None, type=str, required=True, help="Where to save the converted model." - ) - parser.add_argument( - "--tokenizer_file", - default=None, - type=str, - help="Path to the tokenizer file to use (if not provided, only the model is converted).", - ) - parser.add_argument( - "--size", - default=None, - type=str, - help="Size of the model. Will be inferred from the `checkpoint_file` if not passed.", - ) - parser.add_argument( - "--push_to_hub", - action="store_true", - help="Push to the Hub the converted model.", - ) - parser.add_argument( - "--model_name", - default=None, - type=str, - help="Name of the pushed model on the Hub, including the username / organization.", - ) - - args = parser.parse_args() - convert_rmkv_checkpoint_to_hf_format( - args.repo_id, - args.checkpoint_file, - args.output_dir, - size=args.size, - tokenizer_file=args.tokenizer_file, - push_to_hub=args.push_to_hub, - model_name=args.model_name, - ) From fe537285a8a4ed692e2fa4fb3411a430d0980fbd Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 14:04:55 +0100 Subject: [PATCH 042/116] finish fixing the fast path: generation make sense now! --- src/transformers/models/mamba/configuration_mamba.py | 5 ++++- src/transformers/models/mamba/modeling_mamba.py | 9 +++++---- tests/models/mamba/test_modeling_mamba.py | 6 +++--- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 941fcc06b8933c..9fd07106b43dd6 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -66,6 +66,8 @@ class MambaConfig(PretrainedConfig): initializer_range (``, *optional*, defaults to 0.1): residual_in_fp32 (`bool`, *optional*, defaults to `False`): Whether or not residuals should be in `float32`. + use_cache (`bool`, *optional*, defaults to `False`): + Whether or not the cache should be used. Example: @@ -104,6 +106,7 @@ def __init__( hidden_act="silu", initializer_range=0.1, residual_in_fp32=False, + use_cache=True, **kwargs, ): self.vocab_size = vocab_size @@ -113,7 +116,6 @@ def __init__( self.layer_norm_epsilon = layer_norm_epsilon self.d_inner = hidden_size * 2 self.conv_kernel = 4 - self.state_size = state_size self.expand = expand self.time_step_rank = math.ceil(self.hidden_size / 16) if dt_rank == "auto" else dt_rank self.bos_token_id = bos_token_id @@ -126,5 +128,6 @@ def __init__( self.residual_in_fp32 = residual_in_fp32 self.tie_word_embeddings = tie_word_embeddings self.dt_rank = dt_rank + self.use_cache=use_cache super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 28ab0771f668e8..1cfb5301537b95 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -126,6 +126,7 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): ).unsqueeze(-1) else: conv_state = nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0)) + inference_params.conv_states[self.layer_idx].copy_(conv_state) hidden_states = causal_conv1d_fn( x=hidden_states, weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), @@ -169,7 +170,7 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): class MambaCache: - def __init__(self, config, batch_size, conv_dtype=torch.float16, ssm_dtype=torch.float16, device=None): + def __init__(self, config, batch_size, dtype=torch.float16, device=None): self.seqlen_offset = 0 d_model = config.hidden_size @@ -178,11 +179,11 @@ def __init__(self, config, batch_size, conv_dtype=torch.float16, ssm_dtype=torch d_conv = config.conv_kernel self.conv_states = { - i: torch.zeros(batch_size, d_model * expand, d_conv, device=device, dtype=conv_dtype) + i: torch.zeros(batch_size, d_model * expand, d_conv, device=device, dtype=dtype) for i in range(config.num_hidden_layers) } self.ssm_states = { - i: torch.zeros(batch_size, d_model * expand, d_state, device=device, dtype=ssm_dtype) + i: torch.zeros(batch_size, d_model * expand, d_state, device=device, dtype=dtype) for i in range(config.num_hidden_layers) } @@ -525,7 +526,7 @@ def forward( inputs_embeds = self.embeddings(input_ids) if use_cache and inference_params is None: - inference_params = MambaCache(self.config, inputs_embeds.size(0), device=inputs_embeds.device) + inference_params = MambaCache(self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype) if self.gradient_checkpointing and self.training: if use_cache: diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 5fb9f33a13ab74..49b293eb99b651 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -423,7 +423,7 @@ def test_model_from_pretrained(self): class MambaIntegrationTests(unittest.TestCase): def setUp(self): - self.model_id = "state-spaces/mamba-2.8b" + self.model_id = "ArthurZ/mamba-2.8b" self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) def test_simple_generate(self): @@ -435,7 +435,7 @@ def test_simple_generate(self): tokenizer.pad_token = tokenizer.eos_token model = MambaForCausalLM.from_pretrained( - "state-spaces/mamba-130m", vocab_size=50280, num_hidden_layers=24, torch_dtype=torch.float16 + "ArthurZ/mamba-130m", vocab_size=50277, num_hidden_layers=64, torch_dtype=torch.float16, hidden_size=2560 ) model.to(torch_device) model.config.use_cache = True @@ -460,7 +460,7 @@ def test_simple_generate_bf16(self): expected_output = "Hello my name is Jasmine and I am a newbie to the" input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) - model = MambaForCausalLM.from_pretrained(self.model_id, torch_dtype=torch.bfloat16).to(torch_device) + model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", torch_dtype=torch.float16).to(torch_device) output = model.generate(input_ids, max_new_tokens=10) output_sentence = self.tokenizer.decode(output[0].tolist()) From 9411169b90e984f9c46d16a0453c91c8fe34d8de Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 14:05:48 +0100 Subject: [PATCH 043/116] nit --- tests/models/mamba/test_modeling_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 49b293eb99b651..e011d36ed1a9f7 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -456,7 +456,7 @@ def test_simple_generate(self): ], ) - def test_simple_generate_bf16(self): + def test_simple_generate_cuda_kernels(self): expected_output = "Hello my name is Jasmine and I am a newbie to the" input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) From c2c770967f14530e61a3db7064b960470b6ff8e3 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 14:08:00 +0100 Subject: [PATCH 044/116] Let's start working on the CIs --- src/transformers/models/mamba/configuration_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 9fd07106b43dd6..484c606adb97b5 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -128,6 +128,6 @@ def __init__( self.residual_in_fp32 = residual_in_fp32 self.tie_word_embeddings = tie_word_embeddings self.dt_rank = dt_rank - self.use_cache=use_cache + self.use_cache = use_cache super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) From 1e73ca91d507a7348095dca0228280aea164fe33 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 19 Feb 2024 14:13:29 +0100 Subject: [PATCH 045/116] style --- src/transformers/models/mamba/modeling_mamba.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 1cfb5301537b95..eb2dab9aa6d496 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -526,7 +526,9 @@ def forward( inputs_embeds = self.embeddings(input_ids) if use_cache and inference_params is None: - inference_params = MambaCache(self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype) + inference_params = MambaCache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -546,7 +548,6 @@ def forward( else: hidden_states, conv_state, ssm_state = layer(hidden_states, inference_params=inference_params) # inference_params.ssm_states[idx].copy_(ssm_state) - # TODO maybe for torch.compile + graph do things here # inference_params.conv_states[idx].copy_(conv_state) if output_hidden_states: From 221322295afaf7da13e3e07fa38731ec907c2255 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 20 Feb 2024 02:38:01 +0100 Subject: [PATCH 046/116] better style --- .../models/mamba/modeling_mamba.py | 40 +++++++++++-------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index eb2dab9aa6d496..e3ec396a6f7af3 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -115,49 +115,57 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) if inference_params is not None and inference_params.seqlen_offset > 0: conv_state = inference_params.conv_states[self.layer_idx] hidden_states = causal_conv1d_update( - hidden_states.squeeze(-1), - conv_state, - self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), - self.conv1d.bias, - self.activation, + hidden_states.squeeze(-1), conv_state, conv_weights, self.conv1d.bias, self.activation ).unsqueeze(-1) else: conv_state = nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0)) inference_params.conv_states[self.layer_idx].copy_(conv_state) hidden_states = causal_conv1d_fn( - x=hidden_states, - weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), - bias=self.conv1d.bias, + hidden_states, + conv_weights, + self.conv1d.bias, activation=self.activation, ) # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C - x_dbl = self.x_proj(hidden_states.transpose(1, 2)) + x_dbl = self.x_proj(hidden_states.transpose(1, 2)) # TODO find a better name for this one time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1) discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) - A = -torch.exp(self.A_log.float()) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) ssm_state = inference_params.ssm_states[self.layer_idx] if inference_params is not None and inference_params.seqlen_offset > 0: y = selective_state_update( - ssm_state, hidden_states.squeeze(-1), discrete_time_step.squeeze(-1), A, B.squeeze(1), C.squeeze(1), self.D, z=gate.squeeze(-1), dt_bias=self.dt_proj.bias, dt_softplus=True - ).unsqueeze(-1) # fmt: skip + ssm_state, + hidden_states[..., 0], + discrete_time_step[..., 0], + A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + self.dt_proj.bias, + dt_softplus=True, + ).unsqueeze(-1) else: + B = B.transpose(1, 2) + C = C.transpose(1, 2) y, last_state = selective_scan_fn( hidden_states, discrete_time_step, A, - B.transpose(1, 2), - C.transpose(1, 2), + B, + C, self.D.float(), - z=gate, - delta_bias=self.dt_proj.bias.float(), + gate, + self.dt_proj.bias.float(), delta_softplus=True, return_last_state=True, ) From 2a020066402f01587a807342b317c0d516e6b07f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 20 Feb 2024 10:25:04 +0100 Subject: [PATCH 047/116] more nits --- .../models/mamba/modeling_mamba.py | 106 ++++++++---------- 1 file changed, 44 insertions(+), 62 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index e3ec396a6f7af3..af9731ff689af4 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -61,8 +61,12 @@ MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = [ "state-spaces/mamba-130m", - # See all Mamba models at https://huggingface.co/models?filter=mamba -] + "state-spaces/mamba-370m", + "state-spaces/mamba-790m", + "state-spaces/mamba-1.4b", + "state-spaces/mamba-2.8b", + "state-spaces/mamba-2.8b-slimpj", +] # See all Mamba models at https://huggingface.co/models?filter=mamba class MambaMixer(nn.Module): @@ -73,9 +77,7 @@ def __init__(self, config, layer_idx): self.d_conv = config.conv_kernel self.expand = config.expand self.d_inner = int(self.expand * self.d_model) - self.time_step_rank = ( - math.ceil(self.d_model / 16) if config.time_step_rank == "auto" else config.time_step_rank - ) + self.time_step_rank = config.time_step_rank self.layer_idx = layer_idx self.conv1d = nn.Conv1d( @@ -96,17 +98,14 @@ def __init__(self, config, layer_idx): self.x_proj = nn.Linear(self.d_inner, self.time_step_rank + self.d_state * 2, bias=False) # time step projection (discretization) self.dt_proj = nn.Linear(self.time_step_rank, self.d_inner, bias=True) + # S4D real initialization. These are not discretized! - # THe core is to load them, compute the discrete states, then write the updates state. - # Keeps the memory bounded + # THe core is to load them, compute the discrete states, then write the updates state. Keeps the memory bounded A = torch.arange(1, self.d_state + 1, dtype=torch.float32)[None, :].expand(self.d_inner, -1).contiguous() - A_log = torch.log(A) # Keep A_log in fp32 - self.A_log = nn.Parameter(A_log) - self.A_log._no_weight_decay = True + self.A_log = nn.Parameter(torch.log(A)) # TODO this parameter should be kept in float32. We don't have support for that I think # D "skip" parameter self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32 - self.D._no_weight_decay = True self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=config.use_bias) def forward(self, hidden_states: torch.Tensor, inference_params=None): @@ -118,24 +117,19 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) if inference_params is not None and inference_params.seqlen_offset > 0: conv_state = inference_params.conv_states[self.layer_idx] - hidden_states = causal_conv1d_update( - hidden_states.squeeze(-1), conv_state, conv_weights, self.conv1d.bias, self.activation - ).unsqueeze(-1) + hidden_states = causal_conv1d_update(hidden_states.squeeze(-1), conv_state, conv_weights, self.conv1d.bias, self.activation) + hidden_states = hidden_states.unsqueeze(-1) else: conv_state = nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0)) inference_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = causal_conv1d_fn( - hidden_states, - conv_weights, - self.conv1d.bias, - activation=self.activation, - ) + hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation) # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C x_dbl = self.x_proj(hidden_states.transpose(1, 2)) # TODO find a better name for this one time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1) discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) + # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) A = -torch.exp(self.A_log.float()) @@ -155,14 +149,12 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): dt_softplus=True, ).unsqueeze(-1) else: - B = B.transpose(1, 2) - C = C.transpose(1, 2) y, last_state = selective_scan_fn( hidden_states, discrete_time_step, A, - B, - C, + B.transpose(1, 2), + C.transpose(1, 2), self.D.float(), gate, self.dt_proj.bias.float(), @@ -330,6 +322,10 @@ def _init_weights(self, module): nn.init.zeros_(module.bias) elif isinstance(module, nn.Embedding): nn.init.normal_(module.weight, std=self.config.initializer_range) + + # TODO make sure we properly init + # self.A_log._no_weight_decay = True + # self.D._no_weight_decay = True # # # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max # dt = torch.exp( @@ -384,18 +380,16 @@ class MambaOutput(ModelOutput): one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. + ssm_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_ssm_statess=True` is passed or when `config.output_ssm_states=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,sequence_length)`. - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. + State space model states weights after the selective scan. """ last_hidden_state: torch.FloatTensor = None inference_params: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None - states: Optional[Tuple[torch.FloatTensor]] = None + ssm_states: Optional[Tuple[torch.FloatTensor]] = None @dataclass @@ -416,18 +410,17 @@ class MambaCausalLMOutput(ModelOutput): one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. + ssm_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_ssm_statess=True` is passed or when `config.output_ssm_states=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,sequence_length)`. - Last known states. + State space model states weights after the selective scan. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None inference_params: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None - states: Optional[Tuple[torch.FloatTensor]] = None + ssm_states: Optional[Tuple[torch.FloatTensor]] = None MAMBA_START_DOCSTRING = r""" @@ -469,9 +462,8 @@ class MambaCausalLMOutput(ModelOutput): `input_ids` provided as if the model add `state_input_ids + input_ids` as context). use_cache (`bool`, *optional*): If set to `True`, the last state is returned and can be used to quickly generate the next logits. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned - tensors for more detail. + output_ssm_states (`bool`, *optional*): + Whether or not to return the ssm_states of all `MambaMixer` layers. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. @@ -514,21 +506,19 @@ def forward( inputs_embeds: Optional[torch.LongTensor] = None, inference_params: Optional[List[torch.FloatTensor]] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, MambaOutput]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is None and inputs_embeds is None: - raise ValueError("You have to specify either input_ids or inputs_embeds") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) if inputs_embeds is None: inputs_embeds = self.embeddings(input_ids) @@ -541,28 +531,22 @@ def forward( if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`" ) use_cache = False hidden_states = inputs_embeds - all_last_states = () if output_attentions else None all_hidden_states = () if output_hidden_states else None - for idx, layer in enumerate(self.layers): + for mixer_block in enumerate(self.layers): if self.gradient_checkpointing and self.training: - hidden_states, conv_state, ssm_state = self._gradient_checkpointing_func( - layer.__call__, hidden_states, inference_params - ) + hidden_states = self._gradient_checkpointing_func(mixer_block.__call__, hidden_states, inference_params) else: - hidden_states, conv_state, ssm_state = layer(hidden_states, inference_params=inference_params) - # inference_params.ssm_states[idx].copy_(ssm_state) - # inference_params.conv_states[idx].copy_(conv_state) + hidden_states = mixer_block(hidden_states, inference_params=inference_params) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - if output_attentions: - all_last_states = all_last_states + (ssm_state,) + # TODO this is not curcial as long as you know you are decoding vs not decoding inference_params.seqlen_offset += inputs_embeds.shape[1] hidden_states = self.norm_f(hidden_states) @@ -573,15 +557,14 @@ def forward( if not return_dict: return tuple( hidden_states - for hidden_states in [hidden_states, inference_params, all_hidden_states, all_last_states] + for hidden_states in [hidden_states, inference_params if use_cache else None, all_hidden_states] if hidden_states is not None ) return MambaOutput( last_hidden_state=hidden_states, - inference_params=inference_params, + inference_params=inference_params if use_cache else None, hidden_states=all_hidden_states, - states=all_last_states, ) @@ -631,7 +614,6 @@ def prepare_inputs_for_generation( if inference_params is not None: input_ids = input_ids[:, -1].unsqueeze(-1) - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and inference_params is None: model_inputs = {"inputs_embeds": inputs_embeds} else: @@ -652,7 +634,7 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, inference_params: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, - output_attentions: Optional[bool] = None, + output_ssm_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, MambaCausalLMOutput]: @@ -668,7 +650,7 @@ def forward( input_ids, inference_params=inference_params, inputs_embeds=inputs_embeds, - output_attentions=output_attentions, + output_ssm_states=output_ssm_states, output_hidden_states=output_hidden_states, return_dict=return_dict, ) @@ -696,5 +678,5 @@ def forward( logits=logits, inference_params=mamba_outputs.inference_params, hidden_states=mamba_outputs.hidden_states, - states=mamba_outputs.states, + ssm_states=mamba_outputs.ssm_states, ) From 8b0412f30cf82c26ae171df2e9eb7710c71b696b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 20 Feb 2024 10:25:08 +0100 Subject: [PATCH 048/116] test nit --- tests/models/mamba/test_modeling_mamba.py | 185 +++++----------------- 1 file changed, 44 insertions(+), 141 deletions(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index e011d36ed1a9f7..80467e544bdb2d 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -46,15 +46,12 @@ def __init__( batch_size=14, seq_length=7, is_training=True, - use_token_type_ids=False, - use_input_mask=True, use_labels=True, - use_mc_token_ids=True, vocab_size=99, hidden_size=32, num_hidden_layers=2, intermediate_size=37, - hidden_act="gelu", + hidden_act="silu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=512, @@ -68,17 +65,13 @@ def __init__( self.batch_size = batch_size self.seq_length = seq_length self.is_training = is_training - self.use_token_type_ids = use_token_type_ids - self.use_input_mask = use_input_mask self.use_labels = use_labels - self.use_mc_token_ids = use_mc_token_ids self.vocab_size = vocab_size self.hidden_size = hidden_size self.num_hidden_layers = num_hidden_layers self.intermediate_size = intermediate_size self.hidden_act = hidden_act self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob self.max_position_embeddings = max_position_embeddings self.type_vocab_size = type_vocab_size self.type_sequence_label_size = type_sequence_label_size @@ -90,25 +83,13 @@ def __init__( self.pad_token_id = vocab_size - 1 def get_large_model_config(self): - return MambaConfig.from_pretrained("sgugger/mamba-4-pile-7b") + return MambaConfig.from_pretrained("ArthurZ/mamba-130m") def prepare_config_and_inputs( self, gradient_checkpointing=False, scale_attn_by_inverse_layer_idx=False, reorder_and_upcast_attn=False ): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size) - input_mask = None - if self.use_input_mask: - input_mask = random_attention_mask([self.batch_size, self.seq_length]) - - token_type_ids = None - if self.use_token_type_ids: - token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size) - - mc_token_ids = None - if self.use_mc_token_ids: - mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length) - sequence_labels = None token_labels = None choice_labels = None @@ -126,10 +107,7 @@ def prepare_config_and_inputs( return ( config, input_ids, - input_mask, None, - token_type_ids, - mc_token_ids, sequence_labels, token_labels, choice_labels, @@ -144,8 +122,8 @@ def get_config( num_hidden_layers=self.num_hidden_layers, intermediate_size=self.intermediate_size, activation_function=self.hidden_act, - resid_pdrop=self.hidden_dropout_prob, - attn_pdrop=self.attention_probs_dropout_prob, + # resid_pdrop=self.hidden_dropout_prob, + # attn_pdrop=self.attention_probs_dropout_prob, n_positions=self.max_position_embeddings, type_vocab_size=self.type_vocab_size, use_cache=True, @@ -153,8 +131,6 @@ def get_config( eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, gradient_checkpointing=gradient_checkpointing, - scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx, - reorder_and_upcast_attn=reorder_and_upcast_attn, ) def get_pipeline_config(self): @@ -166,32 +142,20 @@ def prepare_config_and_inputs_for_decoder(self): ( config, input_ids, - input_mask, - head_mask, - token_type_ids, - mc_token_ids, sequence_labels, token_labels, choice_labels, ) = self.prepare_config_and_inputs() - encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size]) - encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2) - return ( config, input_ids, - input_mask, - head_mask, - token_type_ids, sequence_labels, token_labels, choice_labels, - encoder_hidden_states, - encoder_attention_mask, ) - def create_and_check_mamba_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + def create_and_check_mamba_model(self, config, input_ids, *args): config.output_hidden_states = True model = MambaModel(config=config) model.to(torch_device) @@ -202,7 +166,7 @@ def create_and_check_mamba_model(self, config, input_ids, input_mask, head_mask, self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) self.parent.assertEqual(len(result.hidden_states), config.num_hidden_layers + 1) - def create_and_check_causl_lm(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + def create_and_check_causl_lm(self, config, input_ids, *args): model = MambaForCausalLM(config) model.to(torch_device) model.eval() @@ -211,7 +175,7 @@ def create_and_check_causl_lm(self, config, input_ids, input_mask, head_mask, to self.parent.assertEqual(result.loss.shape, ()) self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size)) - def create_and_check_state_equivalency(self, config, input_ids, input_mask, head_mask, token_type_ids, *args): + def create_and_check_state_equivalency(self, config, input_ids, *args): model = MambaModel(config=config) model.to(torch_device) model.eval() @@ -229,7 +193,7 @@ def create_and_check_state_equivalency(self, config, input_ids, input_mask, head self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5)) def create_and_check_forward_and_backwards( - self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False + self, config, input_ids, *args, gradient_checkpointing=False ): model = MambaForCausalLM(config) model.to(torch_device) @@ -242,22 +206,8 @@ def create_and_check_forward_and_backwards( result.loss.backward() def prepare_config_and_inputs_for_common(self): - config_and_inputs = self.prepare_config_and_inputs() - - ( - config, - input_ids, - input_mask, - head_mask, - token_type_ids, - mc_token_ids, - sequence_labels, - token_labels, - choice_labels, - ) = config_and_inputs - + config, input_ids = self.prepare_config_and_inputs() inputs_dict = {"input_ids": input_ids} - return config, inputs_dict @@ -267,12 +217,11 @@ def prepare_config_and_inputs_for_common(self): @require_torch class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else () - # all_generative_model_classes = (MambaForCausalLM,) if is_torch_available() else () - fx_compatible = False + fx_compatible = True test_missing_keys = False - test_model_parallel = False + test_model_parallel = True test_pruning = False - test_head_masking = False # Mamba does not support head masking + test_head_masking = False # Mamba does not have attention heads def setUp(self): self.model_tester = MambaModelTester(self) @@ -323,96 +272,27 @@ def test_initialization(self): for model_class in self.all_model_classes: model = model_class(config=config) for name, param in model.named_parameters(): - if "time_decay" in name: + if "A" in name: if param.requires_grad: self.assertTrue(param.data.max().item() == 3.0) self.assertTrue(param.data.min().item() == -5.0) - elif "time_first" in name: + elif "B" in name: if param.requires_grad: # check if it's a ones like self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5)) - elif any(x in name for x in ["time_mix_key", "time_mix_receptance"]): - if param.requires_grad: - self.assertInterval( - param.data, - [0.0, 1.0], - msg=f"Parameter {name} of model {model_class} seems not properly initialized", - ) - elif "time_mix_value" in name: - if param.requires_grad: - self.assertInterval( - param.data, - [0.0, 1.3], - msg=f"Parameter {name} of model {model_class} seems not properly initialized", - ) + # TODO handle initialization scheme! + + @unittest.skip("Mamba does not use attention equivalent test should be `test_ssm_outputs`") def test_attention_outputs(self): r""" Overriding the test_attention_outputs test as the attention outputs of Mamba are different from other models it has a shape `batch_size, seq_len, hidden_size`. """ - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.return_dict = True + pass - seq_len = getattr(self.model_tester, "seq_length", None) - - for model_class in self.all_model_classes: - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = False - config.return_dict = True - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - batch_size = inputs["input_ids"].shape[0] - with torch.no_grad(): - outputs = model(**inputs) - attentions = outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - - # check that output_attentions also work using config - del inputs_dict["output_attentions"] - config.output_attentions = True - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - batch_size = inputs["input_ids"].shape[0] - with torch.no_grad(): - outputs = model(**inputs) - attentions = outputs.attentions - self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) - - self.assertListEqual( - list(attentions[0].shape[-3:]), - [batch_size, seq_len, config.hidden_size], - ) - out_len = len(outputs) - - # Check attention is always last and order is fine - inputs_dict["output_attentions"] = True - inputs_dict["output_hidden_states"] = True - model = model_class(config) - model.to(torch_device) - model.eval() - - inputs = self._prepare_for_class(inputs_dict, model_class) - batch_size = inputs["input_ids"].shape[0] - with torch.no_grad(): - outputs = model(**inputs) - - added_hidden_states = 1 - self.assertEqual(out_len + added_hidden_states, len(outputs)) - - self_attentions = outputs.attentions - - self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) - self.assertListEqual( - list(self_attentions[0].shape[-3:]), - [batch_size, seq_len, config.hidden_size], - ) + def test_ssm_outputs(self): + pass @slow def test_model_from_pretrained(self): @@ -431,7 +311,7 @@ def test_simple_generate(self): from transformers import AutoTokenizer, MambaForCausalLM - tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") + tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m") tokenizer.pad_token = tokenizer.eos_token model = MambaForCausalLM.from_pretrained( @@ -466,3 +346,26 @@ def test_simple_generate_cuda_kernels(self): output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) + + + def test_simple_generate_cuda_kernels(self): + expected_output = "Hello my name is Jasmine and I am a newbie to the" + + input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) + model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-1.4b", torch_dtype=torch.float16).to(torch_device) + + output = model.generate(input_ids, max_new_tokens=10) + output_sentence = self.tokenizer.decode(output[0].tolist()) + + self.assertEqual(output_sentence, expected_output) + + def test_simple_generate_cuda_kernels(self): + expected_output = "Hello my name is Jasmine and I am a newbie to the" + + input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) + model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-2.8b", torch_dtype=torch.float16).to(torch_device) + + output = model.generate(input_ids, max_new_tokens=10) + output_sentence = self.tokenizer.decode(output[0].tolist()) + + self.assertEqual(output_sentence, expected_output) From fbd6a2c0b6731b825d1fa2d8854d39993b91195c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 20 Feb 2024 11:39:38 +0100 Subject: [PATCH 049/116] quick fix for now --- src/transformers/models/mamba/modeling_mamba.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index af9731ff689af4..a3d3d09f8b1d00 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -637,6 +637,7 @@ def forward( output_ssm_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs, #for now we need this for generation ) -> Union[Tuple, MambaCausalLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): From 823f11a73517c8eb7566dd42100a21fa5e3ec638 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 20 Feb 2024 11:50:10 +0100 Subject: [PATCH 050/116] nits --- .../models/mamba/modeling_mamba.py | 19 +++++++++++------- tests/models/mamba/test_modeling_mamba.py | 20 +++++++++++-------- 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index a3d3d09f8b1d00..a8ac9f6f6be6ac 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -14,7 +14,6 @@ # limitations under the License. """PyTorch MAMBA model.""" -import math from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union @@ -66,7 +65,7 @@ "state-spaces/mamba-1.4b", "state-spaces/mamba-2.8b", "state-spaces/mamba-2.8b-slimpj", -] # See all Mamba models at https://huggingface.co/models?filter=mamba +] # See all Mamba models at https://huggingface.co/models?filter=mamba class MambaMixer(nn.Module): @@ -102,7 +101,9 @@ def __init__(self, config, layer_idx): # S4D real initialization. These are not discretized! # THe core is to load them, compute the discrete states, then write the updates state. Keeps the memory bounded A = torch.arange(1, self.d_state + 1, dtype=torch.float32)[None, :].expand(self.d_inner, -1).contiguous() - self.A_log = nn.Parameter(torch.log(A)) # TODO this parameter should be kept in float32. We don't have support for that I think + self.A_log = nn.Parameter( + torch.log(A) + ) # TODO this parameter should be kept in float32. We don't have support for that I think # D "skip" parameter self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32 @@ -117,7 +118,9 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) if inference_params is not None and inference_params.seqlen_offset > 0: conv_state = inference_params.conv_states[self.layer_idx] - hidden_states = causal_conv1d_update(hidden_states.squeeze(-1), conv_state, conv_weights, self.conv1d.bias, self.activation) + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), conv_state, conv_weights, self.conv1d.bias, self.activation + ) hidden_states = hidden_states.unsqueeze(-1) else: conv_state = nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0)) @@ -290,7 +293,7 @@ def __init__(self, config, layer_idx): self.mixer = MIXER_CLS(config, layer_idx=layer_idx) - def forward(self, hidden_states, inference_params=None): + def forward(self, hidden_states, output_ssm_states=False, inference_params=None): residual = hidden_states hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) if self.residual_in_fp32: @@ -539,7 +542,9 @@ def forward( all_hidden_states = () if output_hidden_states else None for mixer_block in enumerate(self.layers): if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func(mixer_block.__call__, hidden_states, inference_params) + hidden_states = self._gradient_checkpointing_func( + mixer_block.__call__, hidden_states, inference_params + ) else: hidden_states = mixer_block(hidden_states, inference_params=inference_params) @@ -637,7 +642,7 @@ def forward( output_ssm_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **kwargs, #for now we need this for generation + **kwargs, # for now we need this for generation ) -> Union[Tuple, MambaCausalLMOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 80467e544bdb2d..09b0408466ad0a 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -22,7 +22,7 @@ from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask +from ...test_modeling_common import ModelTesterMixin, ids_tensor from ...test_pipeline_mixin import PipelineTesterMixin @@ -192,9 +192,7 @@ def create_and_check_state_equivalency(self, config, input_ids, *args): self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5)) - def create_and_check_forward_and_backwards( - self, config, input_ids, *args, gradient_checkpointing=False - ): + def create_and_check_forward_and_backwards(self, config, input_ids, *args, gradient_checkpointing=False): model = MambaForCausalLM(config) model.to(torch_device) if gradient_checkpointing: @@ -206,7 +204,14 @@ def create_and_check_forward_and_backwards( result.loss.backward() def prepare_config_and_inputs_for_common(self): - config, input_ids = self.prepare_config_and_inputs() + ( + config, + input_ids, + _, + sequence_labels, + token_labels, + choice_labels, + ) = self.prepare_config_and_inputs() inputs_dict = {"input_ids": input_ids} return config, inputs_dict @@ -347,8 +352,7 @@ def test_simple_generate_cuda_kernels(self): self.assertEqual(output_sentence, expected_output) - - def test_simple_generate_cuda_kernels(self): + def test_simple_generate_cuda_kernels_mid(self): expected_output = "Hello my name is Jasmine and I am a newbie to the" input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) @@ -359,7 +363,7 @@ def test_simple_generate_cuda_kernels(self): self.assertEqual(output_sentence, expected_output) - def test_simple_generate_cuda_kernels(self): + def test_simple_generate_cuda_kernels_big(self): expected_output = "Hello my name is Jasmine and I am a newbie to the" input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) From 88896a9a8c54b1b864e551c2c6c15534fb20ed78 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 20 Feb 2024 20:41:29 +0900 Subject: [PATCH 051/116] nit --- src/transformers/models/mamba/modeling_mamba.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index a8ac9f6f6be6ac..a64507e84b9883 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -509,6 +509,7 @@ def forward( inputs_embeds: Optional[torch.LongTensor] = None, inference_params: Optional[List[torch.FloatTensor]] = None, use_cache: Optional[bool] = None, + output_ssm_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, MambaOutput]: @@ -540,7 +541,7 @@ def forward( hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None - for mixer_block in enumerate(self.layers): + for mixer_block in self.layers: if self.gradient_checkpointing and self.training: hidden_states = self._gradient_checkpointing_func( mixer_block.__call__, hidden_states, inference_params From 7f72ee817a3623f3e25f38d9af453dab04962e7c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 21 Feb 2024 10:49:16 +0900 Subject: [PATCH 052/116] nit --- src/transformers/models/mamba/modeling_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index a64507e84b9883..af41d377005a2c 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -527,7 +527,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embeddings(input_ids) - if use_cache and inference_params is None: + if inference_params is None: inference_params = MambaCache( self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype ) From 0072a6c27126353b3791c5da2f7e8b2e38bb2cca Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 02:45:26 +0100 Subject: [PATCH 053/116] nit --- .../models/mamba/modeling_mamba.py | 28 ++++--------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index af41d377005a2c..8ede005c9d4c12 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -293,7 +293,7 @@ def __init__(self, config, layer_idx): self.mixer = MIXER_CLS(config, layer_idx=layer_idx) - def forward(self, hidden_states, output_ssm_states=False, inference_params=None): + def forward(self, hidden_states, inference_params=None): residual = hidden_states hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) if self.residual_in_fp32: @@ -378,21 +378,18 @@ class MambaOutput(ModelOutput): inference_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. + + Includes both the State space model states weights after the selective scan, and the Convolutional states hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - ssm_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_ssm_statess=True` is passed or when `config.output_ssm_states=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,sequence_length)`. - - State space model states weights after the selective scan. """ last_hidden_state: torch.FloatTensor = None inference_params: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None - ssm_states: Optional[Tuple[torch.FloatTensor]] = None @dataclass @@ -413,17 +410,12 @@ class MambaCausalLMOutput(ModelOutput): one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. - ssm_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_ssm_statess=True` is passed or when `config.output_ssm_states=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,sequence_length)`. - - State space model states weights after the selective scan. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None inference_params: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None - ssm_states: Optional[Tuple[torch.FloatTensor]] = None MAMBA_START_DOCSTRING = r""" @@ -465,8 +457,6 @@ class MambaCausalLMOutput(ModelOutput): `input_ids` provided as if the model add `state_input_ids + input_ids` as context). use_cache (`bool`, *optional*): If set to `True`, the last state is returned and can be used to quickly generate the next logits. - output_ssm_states (`bool`, *optional*): - Whether or not to return the ssm_states of all `MambaMixer` layers. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. @@ -509,7 +499,6 @@ def forward( inputs_embeds: Optional[torch.LongTensor] = None, inference_params: Optional[List[torch.FloatTensor]] = None, use_cache: Optional[bool] = None, - output_ssm_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, MambaOutput]: @@ -552,7 +541,6 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - # TODO this is not curcial as long as you know you are decoding vs not decoding inference_params.seqlen_offset += inputs_embeds.shape[1] hidden_states = self.norm_f(hidden_states) @@ -561,11 +549,7 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple( - hidden_states - for hidden_states in [hidden_states, inference_params if use_cache else None, all_hidden_states] - if hidden_states is not None - ) + return tuple(v for v in [hidden_states, inference_params, all_hidden_states] if v is not None) return MambaOutput( last_hidden_state=hidden_states, @@ -609,6 +593,7 @@ def _update_model_kwargs_for_generation( model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False, standardize_cache_format: bool = False, + **kwargs ) -> Dict[str, Any]: model_kwargs["inference_params"] = outputs["inference_params"] return model_kwargs @@ -640,7 +625,6 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, inference_params: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, - output_ssm_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, **kwargs, # for now we need this for generation @@ -657,7 +641,6 @@ def forward( input_ids, inference_params=inference_params, inputs_embeds=inputs_embeds, - output_ssm_states=output_ssm_states, output_hidden_states=output_hidden_states, return_dict=return_dict, ) @@ -685,5 +668,4 @@ def forward( logits=logits, inference_params=mamba_outputs.inference_params, hidden_states=mamba_outputs.hidden_states, - ssm_states=mamba_outputs.ssm_states, ) From 7f6c56f4eca7680e32cff25e3fd4b405480a7a4c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 02:52:47 +0100 Subject: [PATCH 054/116] nits --- src/transformers/models/mamba/modeling_mamba.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 8ede005c9d4c12..b52b62034f734c 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -168,8 +168,8 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): inference_params.ssm_states[self.layer_idx].copy_(ssm_state) # 4. Final linear projection - attn_outputs = self.out_proj(y.transpose(1, 2)) - return attn_outputs, conv_state, ssm_state + selected_states = self.out_proj(y.transpose(1, 2)) + return selected_states class MambaCache: @@ -254,11 +254,11 @@ def forward(self, hidden_states, inference_params=None): y = y + (hidden_states * self.D.to(hidden_states.dtype)[None, :, None]) y = y * self.act(gate) # (B D) # 4. Final linear projection - attn_outputs = self.out_proj(y.transpose(1, 2)) + selected_states = self.out_proj(y.transpose(1, 2)) inference_params.ssm_states[self.layer_idx].copy_(ssm_state) - return attn_outputs, None, ssm_state + return selected_states class MambaRMSNorm(nn.Module): @@ -299,9 +299,9 @@ def forward(self, hidden_states, inference_params=None): if self.residual_in_fp32: residual = residual.to(torch.float32) - hidden_states, conv_states, ssm_state = self.mixer(hidden_states, inference_params=inference_params) + hidden_states = self.mixer(hidden_states, inference_params=inference_params) hidden_states = residual.to(torch.float32) + hidden_states - return hidden_states, conv_states, ssm_state + return hidden_states class MambaPreTrainedModel(PreTrainedModel): From f67c3533dce8c810ddecd4e1af2d67ab280689f6 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 02:53:52 +0100 Subject: [PATCH 055/116] update test rest --- tests/models/mamba/test_modeling_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 09b0408466ad0a..827cbe14c71678 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -364,7 +364,7 @@ def test_simple_generate_cuda_kernels_mid(self): self.assertEqual(output_sentence, expected_output) def test_simple_generate_cuda_kernels_big(self): - expected_output = "Hello my name is Jasmine and I am a newbie to the" + expected_output = 'Hello my name is John. I am a student at the University of' input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-2.8b", torch_dtype=torch.float16).to(torch_device) From 2ab5a8651dbbdf56d683b526280f29bb24a39d24 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 02:59:12 +0100 Subject: [PATCH 056/116] fixup --- src/transformers/models/mamba/modeling_mamba.py | 2 +- tests/models/mamba/test_modeling_mamba.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index b52b62034f734c..13253c1adad2ab 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -593,7 +593,7 @@ def _update_model_kwargs_for_generation( model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False, standardize_cache_format: bool = False, - **kwargs + **kwargs, ) -> Dict[str, Any]: model_kwargs["inference_params"] = outputs["inference_params"] return model_kwargs diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 827cbe14c71678..139401e5f017dd 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -364,7 +364,7 @@ def test_simple_generate_cuda_kernels_mid(self): self.assertEqual(output_sentence, expected_output) def test_simple_generate_cuda_kernels_big(self): - expected_output = 'Hello my name is John. I am a student at the University of' + expected_output = "Hello my name is John. I am a student at the University of" input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-2.8b", torch_dtype=torch.float16).to(torch_device) From 8920be32dfa3a3204a2b6e253860658ced9a8199 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 03:06:29 +0100 Subject: [PATCH 057/116] update test --- tests/models/mamba/test_modeling_mamba.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 139401e5f017dd..1d2485a965e720 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -227,6 +227,9 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi test_model_parallel = True test_pruning = False test_head_masking = False # Mamba does not have attention heads + pipeline_model_mapping = ( + {"feature-extraction": MambaModel, "text-generation": MambaForCausalLM} if is_torch_available() else {} + ) def setUp(self): self.model_tester = MambaModelTester(self) @@ -319,9 +322,7 @@ def test_simple_generate(self): tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m") tokenizer.pad_token = tokenizer.eos_token - model = MambaForCausalLM.from_pretrained( - "ArthurZ/mamba-130m", vocab_size=50277, num_hidden_layers=64, torch_dtype=torch.float16, hidden_size=2560 - ) + model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", torch_dtype=torch.float16) model.to(torch_device) model.config.use_cache = True input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(torch_device) @@ -342,7 +343,7 @@ def test_simple_generate(self): ) def test_simple_generate_cuda_kernels(self): - expected_output = "Hello my name is Jasmine and I am a newbie to the" + expected_output = "Hello my name is John of the Golden, and I am the Lord" input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", torch_dtype=torch.float16).to(torch_device) From 87d0664fc4722718833a516e577e735507d144b0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 03:12:03 +0100 Subject: [PATCH 058/116] nit --- tests/models/mamba/test_modeling_mamba.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 1d2485a965e720..f7e422f6d61742 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -53,13 +53,13 @@ def __init__( intermediate_size=37, hidden_act="silu", hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, max_position_embeddings=512, type_vocab_size=16, type_sequence_label_size=2, num_labels=3, num_choices=4, scope=None, + tie_word_embeddings=True, ): self.parent = parent self.batch_size = batch_size @@ -81,6 +81,7 @@ def __init__( self.bos_token_id = vocab_size - 1 self.eos_token_id = vocab_size - 1 self.pad_token_id = vocab_size - 1 + self.tie_word_embeddings = tie_word_embeddings def get_large_model_config(self): return MambaConfig.from_pretrained("ArthurZ/mamba-130m") @@ -122,8 +123,6 @@ def get_config( num_hidden_layers=self.num_hidden_layers, intermediate_size=self.intermediate_size, activation_function=self.hidden_act, - # resid_pdrop=self.hidden_dropout_prob, - # attn_pdrop=self.attention_probs_dropout_prob, n_positions=self.max_position_embeddings, type_vocab_size=self.type_vocab_size, use_cache=True, @@ -131,6 +130,7 @@ def get_config( eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, gradient_checkpointing=gradient_checkpointing, + tie_word_embeddings=self.tie_word_embeddings ) def get_pipeline_config(self): From 8b00d7688e4c8fad156cc0d4c8c71b61b1c7ac28 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 04:05:52 +0100 Subject: [PATCH 059/116] some fixes --- .../models/mamba/configuration_mamba.py | 7 ++- .../models/mamba/modeling_mamba.py | 52 +++++++++---------- tests/models/mamba/test_modeling_mamba.py | 51 ++++++++++++++++-- 3 files changed, 75 insertions(+), 35 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 484c606adb97b5..29438c741f0ed2 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -99,7 +99,7 @@ def __init__( bos_token_id=1, eos_token_id=2, expand=2, - dt_rank="auto", + time_step_rank="auto", tie_word_embeddings=True, use_bias=False, use_conv_bias=True, @@ -114,10 +114,10 @@ def __init__( self.state_size = state_size self.num_hidden_layers = num_hidden_layers self.layer_norm_epsilon = layer_norm_epsilon - self.d_inner = hidden_size * 2 self.conv_kernel = 4 self.expand = expand - self.time_step_rank = math.ceil(self.hidden_size / 16) if dt_rank == "auto" else dt_rank + self.intermediate_size = int(self.expand * self.hidden_size) + self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id @@ -127,7 +127,6 @@ def __init__( self.initializer_range = initializer_range self.residual_in_fp32 = residual_in_fp32 self.tie_word_embeddings = tie_word_embeddings - self.dt_rank = dt_rank self.use_cache = use_cache super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 13253c1adad2ab..9bf835e094fb3a 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -71,20 +71,20 @@ class MambaMixer(nn.Module): def __init__(self, config, layer_idx): super().__init__() - self.d_model = config.hidden_size - self.d_state = config.state_size - self.d_conv = config.conv_kernel + self.hidden_size = config.hidden_size + self.ssm_state_size = config.state_size + self.conv_kernel_size = config.conv_kernel self.expand = config.expand - self.d_inner = int(self.expand * self.d_model) + self.intermediate_size = config.intermediate_size self.time_step_rank = config.time_step_rank self.layer_idx = layer_idx self.conv1d = nn.Conv1d( - in_channels=self.d_inner, - out_channels=self.d_inner, + in_channels=self.intermediate_size, + out_channels=self.intermediate_size, bias=config.use_conv_bias, kernel_size=config.conv_kernel, - groups=self.d_inner, + groups=self.intermediate_size, padding=config.conv_kernel - 1, ) @@ -92,22 +92,22 @@ def __init__(self, config, layer_idx): self.act = ACT2FN[config.hidden_act] # projection of the input hidden states - self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=config.use_bias) + self.in_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=config.use_bias) # selective projection used to make dt, B and C input dependant - self.x_proj = nn.Linear(self.d_inner, self.time_step_rank + self.d_state * 2, bias=False) + self.x_proj = nn.Linear(self.intermediate_size, self.time_step_rank + self.ssm_state_size * 2, bias=False) # time step projection (discretization) - self.dt_proj = nn.Linear(self.time_step_rank, self.d_inner, bias=True) + self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) # S4D real initialization. These are not discretized! # THe core is to load them, compute the discrete states, then write the updates state. Keeps the memory bounded - A = torch.arange(1, self.d_state + 1, dtype=torch.float32)[None, :].expand(self.d_inner, -1).contiguous() + A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :].expand(self.intermediate_size, -1).contiguous() self.A_log = nn.Parameter( torch.log(A) ) # TODO this parameter should be kept in float32. We don't have support for that I think # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.d_inner)) # Keep in fp32 - self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=config.use_bias) + self.D = nn.Parameter(torch.ones(self.intermediate_size)) # Keep in fp32 + self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) def forward(self, hidden_states: torch.Tensor, inference_params=None): # 1. Gated MLP's linear projection @@ -123,14 +123,14 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): ) hidden_states = hidden_states.unsqueeze(-1) else: - conv_state = nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0)) + conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) inference_params.conv_states[self.layer_idx].copy_(conv_state) hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation) # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C x_dbl = self.x_proj(hidden_states.transpose(1, 2)) # TODO find a better name for this one - time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1) + time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1) discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) @@ -176,17 +176,16 @@ class MambaCache: def __init__(self, config, batch_size, dtype=torch.float16, device=None): self.seqlen_offset = 0 - d_model = config.hidden_size - d_state = config.state_size - expand = config.expand - d_conv = config.conv_kernel + intermediate_size = config.intermediate_size + ssm_state_size = config.state_size + conv_kernel_size = config.conv_kernel self.conv_states = { - i: torch.zeros(batch_size, d_model * expand, d_conv, device=device, dtype=dtype) + i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) for i in range(config.num_hidden_layers) } self.ssm_states = { - i: torch.zeros(batch_size, d_model * expand, d_state, device=device, dtype=dtype) + i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) for i in range(config.num_hidden_layers) } @@ -224,16 +223,16 @@ def forward(self, hidden_states, inference_params=None): else: inference_params.conv_states[self.layer_idx].copy_( - nn.functional.pad(hidden_states, (self.d_conv - hidden_states.shape[-1], 0)) + nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C x_dbl = self.x_proj(hidden_states.transpose(1, 2)) - time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.d_state, self.d_state], dim=-1) + time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1) discrete_time_step = self.dt_proj(time_step) - A = -torch.exp(self.A_log.float()) # (d_inner, d_state) + A = -torch.exp(self.A_log.float()) # (intermediate_size, ssm_state_size) # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) @@ -332,7 +331,7 @@ def _init_weights(self, module): # # # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max # dt = torch.exp( - # torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + # torch.rand(self.intermediate_size, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) # + math.log(dt_min) # ).clamp(min=dt_init_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 @@ -541,7 +540,8 @@ def forward( if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - inference_params.seqlen_offset += inputs_embeds.shape[1] + if use_cache: + inference_params.seqlen_offset += inputs_embeds.shape[1] hidden_states = self.norm_f(hidden_states) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index f7e422f6d61742..3f9449d17f03da 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -50,7 +50,7 @@ def __init__( vocab_size=99, hidden_size=32, num_hidden_layers=2, - intermediate_size=37, + intermediate_size=32, hidden_act="silu", hidden_dropout_prob=0.1, max_position_embeddings=512, @@ -183,11 +183,11 @@ def create_and_check_state_equivalency(self, config, input_ids, *args): outputs = model(input_ids) output_whole = outputs.last_hidden_state - outputs = model(input_ids[:, :2]) + outputs = model(input_ids[:, :2], use_cache=True) output_one = outputs.last_hidden_state # Using the state computed on the first inputs, we will get the same output - outputs = model(input_ids[:, 2:], state=outputs.state) + outputs = model(input_ids[:, 2:], inference_params=outputs.inference_params) output_two = outputs.last_hidden_state self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5)) @@ -227,6 +227,7 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi test_model_parallel = True test_pruning = False test_head_masking = False # Mamba does not have attention heads + test_model_parallel = False pipeline_model_mapping = ( {"feature-extraction": MambaModel, "text-generation": MambaForCausalLM} if is_torch_available() else {} ) @@ -262,6 +263,35 @@ def assertInterval(self, member, container, msg=None): def test_config(self): self.config_tester.run_common_tests() + @unittest.skip("No attention in mamba") + def test_retain_grad_hidden_states_attentions(self): + pass + + # @require_torch_multi_gpu + def test_multi_gpu_data_parallel_forward(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # some params shouldn't be scattered by nn.DataParallel + # so just remove them if they are present. + blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"] + for k in blacklist_non_batched_params: + inputs_dict.pop(k, None) + + # move input tensors to cuda:O + for k, v in inputs_dict.items(): + if torch.is_tensor(v): + inputs_dict[k] = v.to(0) + + for model_class in self.all_model_classes: + model = model_class(config=config) + model.to(0) + model.eval() + + # Wrap model in nn.DataParallel + model = torch.nn.DataParallel(model) + with torch.no_grad(): + _ = model(**self._prepare_for_class(inputs_dict, model_class)) + def test_mamba_model(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_mamba_model(*config_and_inputs) @@ -327,11 +357,11 @@ def test_simple_generate(self): model.config.use_cache = True input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(torch_device) - logits = model(input_ids=input_ids) + logits = model(input_ids=input_ids).logits EXPECTED_LOGITS = torch.tensor([ -6.7070, -24.7656, -6.4766, -6.0078, -9.7812, -13.0703, -11.4688, -10.6562, -9.3359, -9.4766, -9.1719, -7.9102, -13.0469, -8.7266, -8.4297, -8.4766, -9.1094, -11.5234, -11.1250, -11.7812, -12.1562, -12.8359, -12.1797, -13.4062, -13.6406, -13.4141, -13.6562, -9.2344, -7.9805, -7.2188, -9.9219, -9.1719, -7.8438, -9.1250, -10.1094, -10.2344, -10.2266, -9.7578, -11.0000, -10.6406], device='cuda:0',dtype=torch.float16) # fmt: skip - torch.testing.assert_allclose(logits, EXPECTED_LOGITS) + torch.testing.assert_allclose(logits[0,0,:40], EXPECTED_LOGITS) out = model.generate(input_ids, max_new_tokens=10) output_sentence = tokenizer.decode(out[0, :]) @@ -353,6 +383,17 @@ def test_simple_generate_cuda_kernels(self): self.assertEqual(output_sentence, expected_output) + def test_simple_generate_cuda_kernels_mid(self): + expected_output = "Hello my name is Jasmine and I am a newbie to the" + + input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) + model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-790m", torch_dtype=torch.float16).to(torch_device) + + output = model.generate(input_ids, max_new_tokens=10) + output_sentence = self.tokenizer.decode(output[0].tolist()) + + self.assertEqual(output_sentence, expected_output) + def test_simple_generate_cuda_kernels_mid(self): expected_output = "Hello my name is Jasmine and I am a newbie to the" From ca9835cf545f564b08704876085d6714683b4a3f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 04:09:50 +0100 Subject: [PATCH 060/116] nits --- src/transformers/models/mamba/modeling_mamba.py | 4 +++- tests/models/mamba/test_modeling_mamba.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 9bf835e094fb3a..61a385a4daa844 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -100,7 +100,9 @@ def __init__(self, config, layer_idx): # S4D real initialization. These are not discretized! # THe core is to load them, compute the discrete states, then write the updates state. Keeps the memory bounded - A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :].expand(self.intermediate_size, -1).contiguous() + A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :] + A = A.expand(self.intermediate_size, -1).contiguous() + self.A_log = nn.Parameter( torch.log(A) ) # TODO this parameter should be kept in float32. We don't have support for that I think diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 3f9449d17f03da..7029ca154c4269 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -130,7 +130,7 @@ def get_config( eos_token_id=self.eos_token_id, pad_token_id=self.pad_token_id, gradient_checkpointing=gradient_checkpointing, - tie_word_embeddings=self.tie_word_embeddings + tie_word_embeddings=self.tie_word_embeddings, ) def get_pipeline_config(self): @@ -270,23 +270,23 @@ def test_retain_grad_hidden_states_attentions(self): # @require_torch_multi_gpu def test_multi_gpu_data_parallel_forward(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - + # some params shouldn't be scattered by nn.DataParallel # so just remove them if they are present. blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"] for k in blacklist_non_batched_params: inputs_dict.pop(k, None) - + # move input tensors to cuda:O for k, v in inputs_dict.items(): if torch.is_tensor(v): inputs_dict[k] = v.to(0) - + for model_class in self.all_model_classes: model = model_class(config=config) model.to(0) model.eval() - + # Wrap model in nn.DataParallel model = torch.nn.DataParallel(model) with torch.no_grad(): @@ -361,7 +361,7 @@ def test_simple_generate(self): EXPECTED_LOGITS = torch.tensor([ -6.7070, -24.7656, -6.4766, -6.0078, -9.7812, -13.0703, -11.4688, -10.6562, -9.3359, -9.4766, -9.1719, -7.9102, -13.0469, -8.7266, -8.4297, -8.4766, -9.1094, -11.5234, -11.1250, -11.7812, -12.1562, -12.8359, -12.1797, -13.4062, -13.6406, -13.4141, -13.6562, -9.2344, -7.9805, -7.2188, -9.9219, -9.1719, -7.8438, -9.1250, -10.1094, -10.2344, -10.2266, -9.7578, -11.0000, -10.6406], device='cuda:0',dtype=torch.float16) # fmt: skip - torch.testing.assert_allclose(logits[0,0,:40], EXPECTED_LOGITS) + torch.testing.assert_allclose(logits[0, 0, :40], EXPECTED_LOGITS) out = model.generate(input_ids, max_new_tokens=10) output_sentence = tokenizer.decode(out[0, :]) @@ -372,7 +372,7 @@ def test_simple_generate(self): ], ) - def test_simple_generate_cuda_kernels(self): + def test_simple_generate_cuda_kernels_tiny(self): expected_output = "Hello my name is John of the Golden, and I am the Lord" input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) @@ -383,7 +383,7 @@ def test_simple_generate_cuda_kernels(self): self.assertEqual(output_sentence, expected_output) - def test_simple_generate_cuda_kernels_mid(self): + def test_simple_generate_cuda_kernels_small(self): expected_output = "Hello my name is Jasmine and I am a newbie to the" input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) From 796ef3efc8ba857e3109b89cd3438dc3af6fc9b4 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 04:50:27 +0100 Subject: [PATCH 061/116] update test values --- tests/models/mamba/test_modeling_mamba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 7029ca154c4269..9151ccc62a45f6 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -384,7 +384,7 @@ def test_simple_generate_cuda_kernels_tiny(self): self.assertEqual(output_sentence, expected_output) def test_simple_generate_cuda_kernels_small(self): - expected_output = "Hello my name is Jasmine and I am a newbie to the" + expected_output = 'Hello my name is\n\nI am a student of the art of' input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-790m", torch_dtype=torch.float16).to(torch_device) @@ -395,7 +395,7 @@ def test_simple_generate_cuda_kernels_small(self): self.assertEqual(output_sentence, expected_output) def test_simple_generate_cuda_kernels_mid(self): - expected_output = "Hello my name is Jasmine and I am a newbie to the" + expected_output = "Hello my name is John and I am a software engineer. I have" input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-1.4b", torch_dtype=torch.float16).to(torch_device) From 170664ac4208de5dc727055ab0a760f832053065 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 04:51:21 +0100 Subject: [PATCH 062/116] fix styling --- tests/models/mamba/test_modeling_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 9151ccc62a45f6..6856c733a549f6 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -384,7 +384,7 @@ def test_simple_generate_cuda_kernels_tiny(self): self.assertEqual(output_sentence, expected_output) def test_simple_generate_cuda_kernels_small(self): - expected_output = 'Hello my name is\n\nI am a student of the art of' + expected_output = "Hello my name is\n\nI am a student of the art of" input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-790m", torch_dtype=torch.float16).to(torch_device) From 92493a05bb3bc52d83fc2e8369456e2735e986d4 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 04:52:12 +0100 Subject: [PATCH 063/116] nit --- src/transformers/models/mamba/configuration_mamba.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 29438c741f0ed2..198981a87663a8 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -100,7 +100,6 @@ def __init__( eos_token_id=2, expand=2, time_step_rank="auto", - tie_word_embeddings=True, use_bias=False, use_conv_bias=True, hidden_act="silu", @@ -126,7 +125,6 @@ def __init__( self.hidden_act = hidden_act self.initializer_range = initializer_range self.residual_in_fp32 = residual_in_fp32 - self.tie_word_embeddings = tie_word_embeddings self.use_cache = use_cache super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) From 854ebad54b49984ee86178e27b892c04bc3c6702 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 06:42:07 +0100 Subject: [PATCH 064/116] support peft --- src/transformers/models/mamba/modeling_mamba.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 61a385a4daa844..5236a734c1a8c4 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -140,6 +140,7 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): # 3.c perform the recurrence y ← SSM(A, B, C)(x) ssm_state = inference_params.ssm_states[self.layer_idx] + time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if inference_params is not None and inference_params.seqlen_offset > 0: y = selective_state_update( ssm_state, @@ -150,7 +151,7 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): C[:, 0], self.D, gate[..., 0], - self.dt_proj.bias, + time_proj_bias, dt_softplus=True, ).unsqueeze(-1) else: @@ -162,7 +163,7 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): C.transpose(1, 2), self.D.float(), gate, - self.dt_proj.bias.float(), + time_proj_bias, delta_softplus=True, return_last_state=True, ) From aa0e6bb3ba0ffa3d19350bc01eabfe87eb54286a Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 07:23:43 +0100 Subject: [PATCH 065/116] integrations tests require torchg --- tests/models/mamba/test_modeling_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 6856c733a549f6..3db204a91e141e 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -338,7 +338,7 @@ def test_model_from_pretrained(self): model = MambaModel.from_pretrained(model_name) self.assertIsNotNone(model) - +@require_torch class MambaIntegrationTests(unittest.TestCase): def setUp(self): self.model_id = "ArthurZ/mamba-2.8b" From 3c1537ef7393073f61077e343f0e3159548854f8 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 07:25:25 +0100 Subject: [PATCH 066/116] also add slow markers --- tests/models/mamba/test_modeling_mamba.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 3db204a91e141e..f5a9dc76ed840c 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -383,6 +383,7 @@ def test_simple_generate_cuda_kernels_tiny(self): self.assertEqual(output_sentence, expected_output) + @slow def test_simple_generate_cuda_kernels_small(self): expected_output = "Hello my name is\n\nI am a student of the art of" @@ -394,6 +395,7 @@ def test_simple_generate_cuda_kernels_small(self): self.assertEqual(output_sentence, expected_output) + @slow def test_simple_generate_cuda_kernels_mid(self): expected_output = "Hello my name is John and I am a software engineer. I have" @@ -405,6 +407,7 @@ def test_simple_generate_cuda_kernels_mid(self): self.assertEqual(output_sentence, expected_output) + @slow def test_simple_generate_cuda_kernels_big(self): expected_output = "Hello my name is John. I am a student at the University of" From d06421a68e6c284b65a0009b1b70b5915352434e Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 07:33:04 +0100 Subject: [PATCH 067/116] styling --- tests/models/mamba/test_modeling_mamba.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index f5a9dc76ed840c..85516f7c2acb05 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -338,6 +338,7 @@ def test_model_from_pretrained(self): model = MambaModel.from_pretrained(model_name) self.assertIsNotNone(model) + @require_torch class MambaIntegrationTests(unittest.TestCase): def setUp(self): From 5fb80623a61cbd6c27b4a5739648f08f865bd679 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 09:07:28 +0100 Subject: [PATCH 068/116] chose forward wisely --- .../models/mamba/modeling_mamba.py | 57 ++++++++++--------- tests/models/mamba/test_modeling_mamba.py | 4 +- 2 files changed, 31 insertions(+), 30 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 5236a734c1a8c4..5d0a3af647a7e0 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -54,6 +54,7 @@ ) causal_conv1d_update, causal_conv1d_fn = None, None +is_fast_path_available = not any((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update)) is None _CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m" _CONFIG_FOR_DOC = "MambaConfig" @@ -111,7 +112,7 @@ def __init__(self, config, layer_idx): self.D = nn.Parameter(torch.ones(self.intermediate_size)) # Keep in fp32 self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) - def forward(self, hidden_states: torch.Tensor, inference_params=None): + def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=None): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) hidden_states, gate = projected_states.chunk(2, dim=1) @@ -175,26 +176,7 @@ def forward(self, hidden_states: torch.Tensor, inference_params=None): return selected_states -class MambaCache: - def __init__(self, config, batch_size, dtype=torch.float16, device=None): - self.seqlen_offset = 0 - - intermediate_size = config.intermediate_size - ssm_state_size = config.state_size - conv_kernel_size = config.conv_kernel - - self.conv_states = { - i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) - for i in range(config.num_hidden_layers) - } - self.ssm_states = { - i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) - for i in range(config.num_hidden_layers) - } - - -class MambaMixerSlow(MambaMixer): - def forward(self, hidden_states, inference_params=None): + def slow_forward(self, hidden_states, inference_params=None): """ Compute ∆ A B C D, the state space parameters. @@ -262,6 +244,31 @@ def forward(self, hidden_states, inference_params=None): return selected_states + def forward(self, hidden_states, inference_params=None): + if is_fast_path_available and "cuda" in self.x_proj.weight.device.type: + return self.cuda_kernel_forward(hidden_states, inference_params) + return self.slow_forward(hidden_states, inference_params) + + +class MambaCache: + def __init__(self, config, batch_size, dtype=torch.float16, device=None): + self.seqlen_offset = 0 + + intermediate_size = config.intermediate_size + ssm_state_size = config.state_size + conv_kernel_size = config.conv_kernel + + self.conv_states = { + i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype) + for i in range(config.num_hidden_layers) + } + self.ssm_states = { + i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype) + for i in range(config.num_hidden_layers) + } + + + class MambaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): @@ -287,13 +294,7 @@ def __init__(self, config, layer_idx): self.layer_idx = layer_idx self.residual_in_fp32 = config.residual_in_fp32 self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - - if any((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update)) is None: - MIXER_CLS = MambaMixerSlow - else: # CUDA is available and the kernels are also available - MIXER_CLS = MambaMixer - - self.mixer = MIXER_CLS(config, layer_idx=layer_idx) + self.mixer = MambaMixer(config, layer_idx=layer_idx) def forward(self, hidden_states, inference_params=None): residual = hidden_states diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 85516f7c2acb05..5ca5856e1443f4 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -360,9 +360,9 @@ def test_simple_generate(self): logits = model(input_ids=input_ids).logits - EXPECTED_LOGITS = torch.tensor([ -6.7070, -24.7656, -6.4766, -6.0078, -9.7812, -13.0703, -11.4688, -10.6562, -9.3359, -9.4766, -9.1719, -7.9102, -13.0469, -8.7266, -8.4297, -8.4766, -9.1094, -11.5234, -11.1250, -11.7812, -12.1562, -12.8359, -12.1797, -13.4062, -13.6406, -13.4141, -13.6562, -9.2344, -7.9805, -7.2188, -9.9219, -9.1719, -7.8438, -9.1250, -10.1094, -10.2344, -10.2266, -9.7578, -11.0000, -10.6406], device='cuda:0',dtype=torch.float16) # fmt: skip + EXPECTED_LOGITS = torch.tensor([ -6.7070, -24.7656, -6.4766, -6.0078, -9.7812, -13.0703, -11.4688, -10.6562, -9.3359, -9.4766, -9.1719, -7.9102, -13.0469, -8.7266, -8.4297, -8.4766, -9.1094, -11.5234, -11.1250, -11.7812, -12.1562, -12.8359, -12.1797, -13.4062, -13.6406, -13.4141, -13.6562, -9.2344, -7.9805, -7.2188, -9.9219, -9.1719, -7.8438, -9.1250, -10.1094, -10.2344, -10.2266, -9.7578, -11.0000, -10.6406],dtype=torch.float16) # fmt: skip - torch.testing.assert_allclose(logits[0, 0, :40], EXPECTED_LOGITS) + torch.testing.assert_allclose(logits[0, 0, :40].cpu(), EXPECTED_LOGITS) out = model.generate(input_ids, max_new_tokens=10) output_sentence = tokenizer.decode(out[0, :]) From edb4e91a7c34b961adb412c52ed97740b79dc5d7 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 09:12:25 +0100 Subject: [PATCH 069/116] nits --- src/transformers/models/mamba/modeling_mamba.py | 2 +- tests/models/mamba/test_modeling_mamba.py | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 5d0a3af647a7e0..e1f58463eb3042 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -246,7 +246,7 @@ def slow_forward(self, hidden_states, inference_params=None): def forward(self, hidden_states, inference_params=None): if is_fast_path_available and "cuda" in self.x_proj.weight.device.type: - return self.cuda_kernel_forward(hidden_states, inference_params) + return self.cuda_kernels_forward(hidden_states, inference_params) return self.slow_forward(hidden_states, inference_params) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 5ca5856e1443f4..162d8439427f31 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -360,7 +360,16 @@ def test_simple_generate(self): logits = model(input_ids=input_ids).logits - EXPECTED_LOGITS = torch.tensor([ -6.7070, -24.7656, -6.4766, -6.0078, -9.7812, -13.0703, -11.4688, -10.6562, -9.3359, -9.4766, -9.1719, -7.9102, -13.0469, -8.7266, -8.4297, -8.4766, -9.1094, -11.5234, -11.1250, -11.7812, -12.1562, -12.8359, -12.1797, -13.4062, -13.6406, -13.4141, -13.6562, -9.2344, -7.9805, -7.2188, -9.9219, -9.1719, -7.8438, -9.1250, -10.1094, -10.2344, -10.2266, -9.7578, -11.0000, -10.6406],dtype=torch.float16) # fmt: skip + EXPECTED_LOGITS = torch.tensor( + [ + -55.6875, -69.7500, -49.9062, -51.7500, -57.6250, -57.9062, -56.9375, + -57.9062, -54.6562, -55.9062, -55.3125, -58.0625, -60.5625, -47.0312, + -52.0312, -49.7812, -55.9375, -57.8750, -56.7500, -57.0938, -57.3438, + -58.2812, -57.7812, -58.7500, -59.5938, -59.0625, -58.6875, -52.9375, + -53.4688, -57.3438, -56.9062, -55.7188, -53.3125, -55.8125, -56.9688, + -56.9062, -56.1875, -54.7188, -56.4062, -57.4688 + ] + ,dtype=torch.float16) # fmt: skip torch.testing.assert_allclose(logits[0, 0, :40].cpu(), EXPECTED_LOGITS) From eb1fb640e39e005e931e8d580d21e9addbb7d8ac Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 09:30:30 +0100 Subject: [PATCH 070/116] update tests --- tests/models/mamba/test_modeling_mamba.py | 28 ++++++++++------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 162d8439427f31..644d3800dda9e3 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -358,29 +358,25 @@ def test_simple_generate(self): model.config.use_cache = True input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(torch_device) - logits = model(input_ids=input_ids).logits + with torch.no_grad(): + logits = model(input_ids=input_ids).logits - EXPECTED_LOGITS = torch.tensor( + EXPECTED_LOGITS_NO_GRAD = torch.tensor( [ - -55.6875, -69.7500, -49.9062, -51.7500, -57.6250, -57.9062, -56.9375, - -57.9062, -54.6562, -55.9062, -55.3125, -58.0625, -60.5625, -47.0312, - -52.0312, -49.7812, -55.9375, -57.8750, -56.7500, -57.0938, -57.3438, - -58.2812, -57.7812, -58.7500, -59.5938, -59.0625, -58.6875, -52.9375, - -53.4688, -57.3438, -56.9062, -55.7188, -53.3125, -55.8125, -56.9688, - -56.9062, -56.1875, -54.7188, -56.4062, -57.4688 + -55.6875, -69.8750, -49.9062, -51.7500, -57.6875, -57.9375, -56.9688, + -57.9375, -54.6875, -55.9375, -55.3125, -58.0938, -60.5625, -47.0000, + -52.0312, -49.7812, -55.9375, -57.9062, -56.7812, -57.1250, -57.3438, + -58.3125, -57.8125, -58.7812, -59.6250, -59.0938, -58.7188, -52.9375, + -53.4688, -57.3750, -56.9375, -55.7500, -53.3125, -55.8438, -57.0000, + -56.9062, -56.2188, -54.7188, -56.4375, -57.5000 ] ,dtype=torch.float16) # fmt: skip - torch.testing.assert_allclose(logits[0, 0, :40].cpu(), EXPECTED_LOGITS) + torch.testing.assert_allclose(logits[0, 0, :40].cpu(), EXPECTED_LOGITS_NO_GRAD) - out = model.generate(input_ids, max_new_tokens=10) + out = model.generate(input_ids, do_sample=False, max_new_tokens=10) output_sentence = tokenizer.decode(out[0, :]) - self.assertEqual( - output_sentence, - [ - "Hey how are you doing?\n\nI'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm so glad you're here. I'm" - ], - ) + self.assertEqual(output_sentence,"Hey how are you doing?\n\nA:\n\nI have a similar") def test_simple_generate_cuda_kernels_tiny(self): expected_output = "Hello my name is John of the Golden, and I am the Lord" From de4fe46e5c2b5875da37a00ad01e1f5e978c9680 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 09:55:38 +0100 Subject: [PATCH 071/116] fix gradient checkpointing --- src/transformers/models/mamba/modeling_mamba.py | 14 ++++++-------- tests/models/mamba/test_modeling_mamba.py | 2 +- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index e1f58463eb3042..bf7608036f94fd 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -200,14 +200,14 @@ def slow_forward(self, hidden_states, inference_params=None): # 2. Convolution sequence transformation if inference_params.seqlen_offset > 0: conv_state = inference_params.conv_states[self.layer_idx] - conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1)) - conv_state[:, :, -1].copy_(hidden_states[:, :, 0]) + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + conv_state[:, :, -1] = hidden_states[:, :, 0] hidden_states = self.act(torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + self.conv1d.bias) hidden_states = hidden_states.unsqueeze(-1) else: - inference_params.conv_states[self.layer_idx].copy_( + inference_params.conv_states[self.layer_idx] = ( nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) @@ -230,17 +230,17 @@ def slow_forward(self, hidden_states, inference_params=None): ssm_state = inference_params.ssm_states[self.layer_idx] ys = [] for i in range(seq_len): - ssm_state.copy_(ssm_state * dA[:, :, i, :] + deltaB_u[:, :, i, :]) + ssm_state = (ssm_state * dA[:, :, i, :] + deltaB_u[:, :, i, :]) # [b, d, n] X [b, n] -> [b, d] y = torch.matmul(ssm_state, C[:, i, :].unsqueeze(-1)) ys.append(y[:, :, 0]) y = torch.stack(ys, dim=-1) # shape (b, l, d) - y = y + (hidden_states * self.D.to(hidden_states.dtype)[None, :, None]) + y = y + (hidden_states * self.D[None, :, None]) y = y * self.act(gate) # (B D) # 4. Final linear projection selected_states = self.out_proj(y.transpose(1, 2)) - inference_params.ssm_states[self.layer_idx].copy_(ssm_state) + inference_params.ssm_states[self.layer_idx] = ssm_state return selected_states @@ -268,8 +268,6 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None): } - - class MambaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): """ diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 644d3800dda9e3..1227280bd11649 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -372,7 +372,7 @@ def test_simple_generate(self): ] ,dtype=torch.float16) # fmt: skip - torch.testing.assert_allclose(logits[0, 0, :40].cpu(), EXPECTED_LOGITS_NO_GRAD) + torch.testing.assert_close(logits[0, 0, :40].cpu(), EXPECTED_LOGITS_NO_GRAD) out = model.generate(input_ids, do_sample=False, max_new_tokens=10) output_sentence = tokenizer.decode(out[0, :]) From 54ffaa3e7a8b07d950c8ab077b82528beaf593b8 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 10:08:48 +0100 Subject: [PATCH 072/116] fixup --- src/transformers/models/mamba/modeling_mamba.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index bf7608036f94fd..41f66ecd9beee8 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -54,7 +54,9 @@ ) causal_conv1d_update, causal_conv1d_fn = None, None -is_fast_path_available = not any((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update)) is None +is_fast_path_available = ( + any((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update)) is not None +) _CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m" _CONFIG_FOR_DOC = "MambaConfig" @@ -175,7 +177,6 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=Non selected_states = self.out_proj(y.transpose(1, 2)) return selected_states - def slow_forward(self, hidden_states, inference_params=None): """ @@ -207,8 +208,8 @@ def slow_forward(self, hidden_states, inference_params=None): hidden_states = hidden_states.unsqueeze(-1) else: - inference_params.conv_states[self.layer_idx] = ( - nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) + inference_params.conv_states[self.layer_idx] = nn.functional.pad( + hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) @@ -230,7 +231,7 @@ def slow_forward(self, hidden_states, inference_params=None): ssm_state = inference_params.ssm_states[self.layer_idx] ys = [] for i in range(seq_len): - ssm_state = (ssm_state * dA[:, :, i, :] + deltaB_u[:, :, i, :]) + ssm_state = ssm_state * dA[:, :, i, :] + deltaB_u[:, :, i, :] # [b, d, n] X [b, n] -> [b, d] y = torch.matmul(ssm_state, C[:, i, :].unsqueeze(-1)) ys.append(y[:, :, 0]) From 977d34f09173b334e02e2a90789f4c21eb1fd474 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 10:08:53 +0100 Subject: [PATCH 073/116] nit --- tests/models/mamba/test_modeling_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 1227280bd11649..ea4badd4c1d86c 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -376,7 +376,7 @@ def test_simple_generate(self): out = model.generate(input_ids, do_sample=False, max_new_tokens=10) output_sentence = tokenizer.decode(out[0, :]) - self.assertEqual(output_sentence,"Hey how are you doing?\n\nA:\n\nI have a similar") + self.assertEqual(output_sentence, "Hey how are you doing?\n\nA:\n\nI have a similar") def test_simple_generate_cuda_kernels_tiny(self): expected_output = "Hello my name is John of the Golden, and I am the Lord" From 0928453b3854380f7ff1784777f4424a2f930111 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 10:09:53 +0100 Subject: [PATCH 074/116] fix doc --- src/transformers/models/mamba/modeling_mamba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 41f66ecd9beee8..0673b918ca0810 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -58,11 +58,11 @@ any((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update)) is not None ) -_CHECKPOINT_FOR_DOC = "state-spaces/mamba-130m" +_CHECKPOINT_FOR_DOC = "ArthurZ/mamba-130m" _CONFIG_FOR_DOC = "MambaConfig" MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "state-spaces/mamba-130m", + "ArthurZ/mamba-130m", "state-spaces/mamba-370m", "state-spaces/mamba-790m", "state-spaces/mamba-1.4b", From 2c90536a373421d37b189d3de60ed916ca4386f1 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 10:22:26 +0100 Subject: [PATCH 075/116] check copies --- src/transformers/models/mamba/configuration_mamba.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 198981a87663a8..de0494f4afd6d2 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -57,16 +57,14 @@ class MambaConfig(PretrainedConfig): The id of the end of sentence token in the vocabulary. Defaults to 0 as MAMBA uses the same tokenizer as GPTNeoX. expand (``, *optional*, defaults to 2): - dt_rank (``, *optional*, defaults to `"auto"`): - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether or not to tie the word embeddings with the input token embeddings. + time_step_rank (``, *optional*, defaults to `"auto"`): use_bias (``, *optional*, defaults to `False`): use_conv_bias (``, *optional*, defaults to `True`): hidden_act (``, *optional*, defaults to `"silu"`): initializer_range (``, *optional*, defaults to 0.1): residual_in_fp32 (`bool`, *optional*, defaults to `False`): Whether or not residuals should be in `float32`. - use_cache (`bool`, *optional*, defaults to `False`): + use_cache (`bool`, *optional*, defaults to `True`): Whether or not the cache should be used. From 4ba9c7925493cafdc8f4406bae3b621fa56a27da Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 10:34:48 +0100 Subject: [PATCH 076/116] fix the docstring --- .../models/mamba/configuration_mamba.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index de0494f4afd6d2..c214aef1cee2a1 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -44,24 +44,29 @@ class MambaConfig(PretrainedConfig): `inputs_ids` passed when calling [`MambaModel`]. hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the embeddings and hidden states. - state_size (``, *optional*, defaults to 16): + state_size (`int`, *optional*, defaults to 16): shape of the state space latents. num_hidden_layers (`int`, *optional*, defaults to 32): Number of hidden layers in the model. layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): The epsilon to use in the layer normalization layers. - pad_token_id (``, *optional*, defaults to 0): + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. bos_token_id (`int`, *optional*, defaults to 1): The id of the beginning of sentence token in the vocabulary. Defaults to 0 as MAMBA uses the same tokenizer as GPTNeoX. eos_token_id (`int`, *optional*, defaults to 2): The id of the end of sentence token in the vocabulary. Defaults to 0 as MAMBA uses the same tokenizer as GPTNeoX. - expand (``, *optional*, defaults to 2): - time_step_rank (``, *optional*, defaults to `"auto"`): - use_bias (``, *optional*, defaults to `False`): - use_conv_bias (``, *optional*, defaults to `True`): - hidden_act (``, *optional*, defaults to `"silu"`): - initializer_range (``, *optional*, defaults to 0.1): + expand (`int`, *optional*, defaults to 2): Expanding factor used to determin the intermediate size. + time_step_rank (`int`, *optional*, defaults to `"auto"`): rank fo the discretization projection matrix. + use_bias (`bool`, *optional*, defaults to `False`): + Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block + use_conv_bias (`bool`, *optional*, defaults to `True`): + Whether or not to use bias in the convolution layer of the mixer block. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + initializer_range (`float`, *optional*, defaults to 0.1): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. residual_in_fp32 (`bool`, *optional*, defaults to `False`): Whether or not residuals should be in `float32`. use_cache (`bool`, *optional*, defaults to `True`): @@ -84,7 +89,6 @@ class MambaConfig(PretrainedConfig): ```""" model_type = "mamba" - attribute_map = {"max_position_embeddings": "context_length"} def __init__( self, From 3651dbad1bcef3afbab3e7693dc2117940becd06 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 10:46:04 +0100 Subject: [PATCH 077/116] fix some more tests --- src/transformers/models/mamba/configuration_mamba.py | 2 +- src/transformers/models/mamba/modeling_mamba.py | 1 + tests/models/mamba/test_modeling_mamba.py | 5 +++-- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index c214aef1cee2a1..156d37df7af203 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -129,4 +129,4 @@ def __init__( self.residual_in_fp32 = residual_in_fp32 self.use_cache = use_cache - super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 0673b918ca0810..64b8230aea6c7e 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -503,6 +503,7 @@ def forward( use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs, # `attention_mask` is passed by the tokenizer and we don't want it ) -> Union[Tuple, MambaOutput]: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index ea4badd4c1d86c..2950def416adc9 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -222,9 +222,10 @@ def prepare_config_and_inputs_for_common(self): @require_torch class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else () - fx_compatible = True + fx_compatible = False # FIXME let's try to support this @ArthurZucker + test_torchscript = False # FIXME let's try to support this @ArthurZucker test_missing_keys = False - test_model_parallel = True + test_model_parallel = False test_pruning = False test_head_masking = False # Mamba does not have attention heads test_model_parallel = False From 426e6f39427562e9b85450d5fed4e0087153b21a Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Thu, 29 Feb 2024 10:47:42 +0100 Subject: [PATCH 078/116] style --- src/transformers/models/mamba/modeling_mamba.py | 2 +- tests/models/mamba/test_modeling_mamba.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 64b8230aea6c7e..6b35109f3b522c 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -503,7 +503,7 @@ def forward( use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **kwargs, # `attention_mask` is passed by the tokenizer and we don't want it + **kwargs, # `attention_mask` is passed by the tokenizer and we don't want it ) -> Union[Tuple, MambaOutput]: output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 2950def416adc9..30bb7efad68109 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -222,8 +222,8 @@ def prepare_config_and_inputs_for_common(self): @require_torch class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else () - fx_compatible = False # FIXME let's try to support this @ArthurZucker - test_torchscript = False # FIXME let's try to support this @ArthurZucker + fx_compatible = False # FIXME let's try to support this @ArthurZucker + test_torchscript = False # FIXME let's try to support this @ArthurZucker test_missing_keys = False test_model_parallel = False test_pruning = False From 951b1aabf2df6889d74493333276ab41c0be85e8 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Mar 2024 09:55:25 +0900 Subject: [PATCH 079/116] fix beam search --- src/transformers/generation/utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5b7d18e06c1d10..97707d3bdf79cb 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3152,7 +3152,7 @@ def beam_search( model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) - if model_kwargs["past_key_values"] is not None: + if model_kwargs.get("past_key_values",None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], beam_idx ) @@ -3499,7 +3499,7 @@ def beam_sample( model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) - if model_kwargs["past_key_values"] is not None: + if model_kwargs.get("past_key_values",None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], beam_idx ) @@ -3898,7 +3898,7 @@ def group_beam_search( model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) - if model_kwargs["past_key_values"] is not None: + if model_kwargs.get("past_key_values",None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], reordering_indices ) @@ -4250,7 +4250,7 @@ def constrained_beam_search( model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) - if model_kwargs["past_key_values"] is not None: + if model_kwargs.get("past_key_values",None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], beam_idx ) From 4101369cb9d14227018709f389c0b0477056772b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Mar 2024 09:58:39 +0100 Subject: [PATCH 080/116] add init schene --- src/transformers/generation/utils.py | 8 +- .../models/mamba/modeling_mamba.py | 92 +++++++++---------- 2 files changed, 49 insertions(+), 51 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 97707d3bdf79cb..84386701489951 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -3152,7 +3152,7 @@ def beam_search( model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) - if model_kwargs.get("past_key_values",None) is not None: + if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], beam_idx ) @@ -3499,7 +3499,7 @@ def beam_sample( model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) - if model_kwargs.get("past_key_values",None) is not None: + if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], beam_idx ) @@ -3898,7 +3898,7 @@ def group_beam_search( model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) - if model_kwargs.get("past_key_values",None) is not None: + if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], reordering_indices ) @@ -4250,7 +4250,7 @@ def constrained_beam_search( model_kwargs = self._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder, model_inputs=model_inputs ) - if model_kwargs.get("past_key_values",None) is not None: + if model_kwargs.get("past_key_values", None) is not None: model_kwargs["past_key_values"] = self._temporary_reorder_cache( model_kwargs["past_key_values"], beam_idx ) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 6b35109f3b522c..3f7b4f5a0e47ec 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -15,6 +15,7 @@ """PyTorch MAMBA model.""" from dataclasses import dataclass +import math from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -320,53 +321,50 @@ class MambaPreTrainedModel(PreTrainedModel): def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, MambaMixer): - pass - if isinstance(module, nn.Linear): - if module.bias is not None: - if not getattr(module.bias, "_no_reinit", False): - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, std=self.config.initializer_range) - - # TODO make sure we properly init - # self.A_log._no_weight_decay = True - # self.D._no_weight_decay = True - # - # # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max - # dt = torch.exp( - # torch.rand(self.intermediate_size, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) - # + math.log(dt_min) - # ).clamp(min=dt_init_floor) - # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 - # inv_dt = dt + torch.log(-torch.expm1(-dt)) - # with torch.no_grad(): - # self.dt_proj.bias.copy_(inv_dt) - # # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit - # self.dt_proj.bias._no_reinit = True - - # if isinstance(module, nn.Linear): - # if module.bias is not None: - # if not getattr(module.bias, "_no_reinit", False): - # nn.init.zeros_(module.bias) - # elif isinstance(module, nn.Embedding): - # nn.init.normal_(module.weight, std=initializer_range) - # - # if rescale_prenorm_residual: - # # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: - # # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale - # # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. - # # > -- GPT-2 :: https://openai.com/blog/better-language-models/ - # # - # # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - # for name, p in module.named_parameters(): - # if name in ["out_proj.weight", "fc2.weight"]: - # # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - # # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) - # # We need to reinit p since this code could be called multiple times - # # Having just p *= scale would repeatedly scale it down - # nn.init.kaiming_uniform_(p, a=math.sqrt(5)) - # with torch.no_grad(): - # p /= math.sqrt(n_residuals_per_layer * n_layer) + module.A_log._no_weight_decay = True + module.D._no_weight_decay = True + + dt_init_std = self.dt_rank**-0.5 * self.config.dt_scale + if self.config.dt_init == "constant": + nn.init.constant_(self.dt_proj.weight, dt_init_std) + elif self.config.dt_init == "random": + nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) + + dt = torch.exp( + torch.rand(self.intermediate_size) * (math.log(self.config.dt_max) - math.log(self.config.dt_min)) + + math.log(self.config.dt_min) + ).clamp(min=self.config.dt_init_floor) + # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 + inv_dt = dt + torch.log(-torch.expm1(-dt)) + with torch.no_grad(): + self.dt_proj.bias.copy_(inv_dt) + self.dt_proj.bias._no_reinit = True + + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(n_residuals_per_layer * self.config.num_layers) + + @dataclass From 65db96bd6ef8623369286326079d219d8217088f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Mar 2024 10:17:48 +0100 Subject: [PATCH 081/116] update --- .../models/mamba/configuration_mamba.py | 25 ++++++++++++++--- .../models/mamba/modeling_mamba.py | 27 +++++++++---------- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 156d37df7af203..8f2326d38205fa 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -58,7 +58,7 @@ class MambaConfig(PretrainedConfig): The id of the end of sentence token in the vocabulary. Defaults to 0 as MAMBA uses the same tokenizer as GPTNeoX. expand (`int`, *optional*, defaults to 2): Expanding factor used to determin the intermediate size. - time_step_rank (`int`, *optional*, defaults to `"auto"`): rank fo the discretization projection matrix. + conv_kernel (``, *optional*, defaults to 4): use_bias (`bool`, *optional*, defaults to `False`): Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block use_conv_bias (`bool`, *optional*, defaults to `True`): @@ -69,6 +69,12 @@ class MambaConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. residual_in_fp32 (`bool`, *optional*, defaults to `False`): Whether or not residuals should be in `float32`. + time_step_rank (`int`, *optional*, defaults to `"auto"`): rank fo the discretization projection matrix. + time_step_scale (``, *optional*, defaults to 0): + time_step_init_floor (``, *optional*, defaults to 0): + time_step_min (``, *optional*, defaults to 0): + time_step_max (``, *optional*, defaults to 0): + rescale_prenorm_residual (``, *optional*, defaults to `False`): use_cache (`bool`, *optional*, defaults to `True`): Whether or not the cache should be used. @@ -101,12 +107,18 @@ def __init__( bos_token_id=1, eos_token_id=2, expand=2, - time_step_rank="auto", + conv_kernel=4, use_bias=False, use_conv_bias=True, hidden_act="silu", initializer_range=0.1, residual_in_fp32=False, + time_step_rank="auto", + time_step_scale=0, + time_step_init_floor=0, + time_step_min=0, + time_step_max=0, + rescale_prenorm_residual=False, use_cache=True, **kwargs, ): @@ -115,10 +127,9 @@ def __init__( self.state_size = state_size self.num_hidden_layers = num_hidden_layers self.layer_norm_epsilon = layer_norm_epsilon - self.conv_kernel = 4 + self.conv_kernel = conv_kernel self.expand = expand self.intermediate_size = int(self.expand * self.hidden_size) - self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id @@ -126,6 +137,12 @@ def __init__( self.use_conv_bias = use_conv_bias self.hidden_act = hidden_act self.initializer_range = initializer_range + self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank + self.time_step_scale = time_step_scale + self.time_step_init_floor = time_step_init_floor + self.time_step_min = time_step_min + self.time_step_max = time_step_max + self.rescale_prenorm_residual = rescale_prenorm_residual self.residual_in_fp32 = residual_in_fp32 self.use_cache = use_cache diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 3f7b4f5a0e47ec..30b50b7597764f 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -43,7 +43,7 @@ from mamba_ssm.ops.triton.selective_state_update import selective_state_update else: logger.warning_once( - " `mamba_ssm` is not installed in your environnement. Make sure to install it following `src/transformers/kernels/mamba/Makefile`" + "The `mamba_ssm` package is not installed in your environnement. Make sure to install it if you want to use the custom cuda kernels" ) selective_state_update, selective_scan_fn = None, None @@ -51,7 +51,7 @@ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update else: logger.warning_once( - " `causal_conv1d` is not installed in your environnement. Make sure to install it: `src/transformers/kernels/mamba/Makefile`" + "The `causal_conv1d` package is not installed in your environnement. Make sure to install it if you want to use the custom cuda kernels" ) causal_conv1d_update, causal_conv1d_fn = None, None @@ -78,7 +78,6 @@ def __init__(self, config, layer_idx): self.hidden_size = config.hidden_size self.ssm_state_size = config.state_size self.conv_kernel_size = config.conv_kernel - self.expand = config.expand self.intermediate_size = config.intermediate_size self.time_step_rank = config.time_step_rank self.layer_idx = layer_idx @@ -107,10 +106,8 @@ def __init__(self, config, layer_idx): A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :] A = A.expand(self.intermediate_size, -1).contiguous() - self.A_log = nn.Parameter( - torch.log(A) - ) # TODO this parameter should be kept in float32. We don't have support for that I think - + # TODO this parameter should be kept in float32. We don't have support for in _keep_in_float32 + self.A_log = nn.Parameter(torch.log(A)) # D "skip" parameter self.D = nn.Parameter(torch.ones(self.intermediate_size)) # Keep in fp32 self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) @@ -135,8 +132,8 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=Non # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C - x_dbl = self.x_proj(hidden_states.transpose(1, 2)) # TODO find a better name for this one - time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1) + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) # TODO find a better name for this one + time_step, B, C = torch.split(ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1) discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) @@ -216,8 +213,8 @@ def slow_forward(self, hidden_states, inference_params=None): # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C - x_dbl = self.x_proj(hidden_states.transpose(1, 2)) - time_step, B, C = torch.split(x_dbl, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1) + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split(ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1) discrete_time_step = self.dt_proj(time_step) A = -torch.exp(self.A_log.float()) # (intermediate_size, ssm_state_size) @@ -324,15 +321,15 @@ def _init_weights(self, module): module.A_log._no_weight_decay = True module.D._no_weight_decay = True - dt_init_std = self.dt_rank**-0.5 * self.config.dt_scale + dt_init_std = self.time_step_rank**-0.5 * self.config.time_step_scale if self.config.dt_init == "constant": nn.init.constant_(self.dt_proj.weight, dt_init_std) elif self.config.dt_init == "random": nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) dt = torch.exp( - torch.rand(self.intermediate_size) * (math.log(self.config.dt_max) - math.log(self.config.dt_min)) - + math.log(self.config.dt_min) + torch.rand(self.intermediate_size) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + + math.log(self.config.time_step_min) ).clamp(min=self.config.dt_init_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) @@ -362,7 +359,7 @@ def _init_weights(self, module): # Having just p *= scale would repeatedly scale it down nn.init.kaiming_uniform_(p, a=math.sqrt(5)) with torch.no_grad(): - p /= math.sqrt(n_residuals_per_layer * self.config.num_layers) + p /= math.sqrt(self.config.num_layers) From 0f3dfc71d9e1a472c6a08fbc3c3b2b42558cebf2 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Mar 2024 10:18:39 +0100 Subject: [PATCH 082/116] nit --- src/transformers/models/mamba/modeling_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 30b50b7597764f..4393969e307627 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -300,7 +300,7 @@ def forward(self, hidden_states, inference_params=None): residual = residual.to(torch.float32) hidden_states = self.mixer(hidden_states, inference_params=inference_params) - hidden_states = residual.to(torch.float32) + hidden_states + hidden_states = residual + hidden_states return hidden_states From f8bd0aa8aa78b2921c2fecd328443b9c5f6c4b7b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Mar 2024 10:28:56 +0100 Subject: [PATCH 083/116] fix --- .../models/mamba/configuration_mamba.py | 12 ++-- .../models/mamba/modeling_mamba.py | 66 +++++++++---------- 2 files changed, 40 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 8f2326d38205fa..1861b07e1f3c21 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -114,10 +114,11 @@ def __init__( initializer_range=0.1, residual_in_fp32=False, time_step_rank="auto", - time_step_scale=0, - time_step_init_floor=0, - time_step_min=0, - time_step_max=0, + time_step_scale=1.0, + time_step_min=0.001, + time_step_max=.1, + time_step_init_scheme="random", + time_step_floor=1e-4, rescale_prenorm_residual=False, use_cache=True, **kwargs, @@ -139,9 +140,10 @@ def __init__( self.initializer_range = initializer_range self.time_step_rank = math.ceil(self.hidden_size / 16) if time_step_rank == "auto" else time_step_rank self.time_step_scale = time_step_scale - self.time_step_init_floor = time_step_init_floor self.time_step_min = time_step_min self.time_step_max = time_step_max + self.time_step_init_scheme = time_step_init_scheme + self.time_step_floor = time_step_floor self.rescale_prenorm_residual = rescale_prenorm_residual self.residual_in_fp32 = residual_in_fp32 self.use_cache = use_cache diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 4393969e307627..097338b7041414 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -321,45 +321,45 @@ def _init_weights(self, module): module.A_log._no_weight_decay = True module.D._no_weight_decay = True - dt_init_std = self.time_step_rank**-0.5 * self.config.time_step_scale - if self.config.dt_init == "constant": - nn.init.constant_(self.dt_proj.weight, dt_init_std) - elif self.config.dt_init == "random": - nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std) + dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale + if self.config.time_step_init_scheme == "constant": + nn.init.constant_(module.dt_proj.weight, dt_init_std) + elif self.config.time_step_init_scheme == "random": + nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std) dt = torch.exp( - torch.rand(self.intermediate_size) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + torch.rand(self.config.intermediate_size) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + math.log(self.config.time_step_min) - ).clamp(min=self.config.dt_init_floor) + ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): - self.dt_proj.bias.copy_(inv_dt) - self.dt_proj.bias._no_reinit = True - - if isinstance(module, nn.Linear): - if module.bias is not None: - if not getattr(module.bias, "_no_reinit", False): - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, std=self.config.initializer_range) - - if self.config.rescale_prenorm_residual: - # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: - # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale - # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. - # > -- GPT-2 :: https://openai.com/blog/better-language-models/ - # - # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - for name, p in module.named_parameters(): - if name in ["out_proj.weight", "fc2.weight"]: - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) - # We need to reinit p since this code could be called multiple times - # Having just p *= scale would repeatedly scale it down - nn.init.kaiming_uniform_(p, a=math.sqrt(5)) - with torch.no_grad(): - p /= math.sqrt(self.config.num_layers) + module.dt_proj.bias.copy_(inv_dt) + module.dt_proj.bias._no_reinit = True + + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=self.config.initializer_range) + + if self.config.rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(self.config.num_layers) From b2bd0c78457eda42c3b5d200a6b558e4b74018a1 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Mar 2024 10:39:36 +0100 Subject: [PATCH 084/116] fixup the doc --- .../models/mamba/configuration_mamba.py | 14 +++++++------- src/transformers/models/mamba/modeling_mamba.py | 17 ++++++++++------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 1861b07e1f3c21..1e2c41d3380a6c 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -70,10 +70,11 @@ class MambaConfig(PretrainedConfig): residual_in_fp32 (`bool`, *optional*, defaults to `False`): Whether or not residuals should be in `float32`. time_step_rank (`int`, *optional*, defaults to `"auto"`): rank fo the discretization projection matrix. - time_step_scale (``, *optional*, defaults to 0): - time_step_init_floor (``, *optional*, defaults to 0): - time_step_min (``, *optional*, defaults to 0): - time_step_max (``, *optional*, defaults to 0): + time_step_scale (``, *optional*, defaults to 1.0): + time_step_min (``, *optional*, defaults to 0.001): + time_step_max (``, *optional*, defaults to 0.1): + time_step_init_scheme (``, *optional*, defaults to `"random"`): + time_step_floor (``, *optional*, defaults to 0.0001): rescale_prenorm_residual (``, *optional*, defaults to `False`): use_cache (`bool`, *optional*, defaults to `True`): Whether or not the cache should be used. @@ -116,7 +117,7 @@ def __init__( time_step_rank="auto", time_step_scale=1.0, time_step_min=0.001, - time_step_max=.1, + time_step_max=0.1, time_step_init_scheme="random", time_step_floor=1e-4, rescale_prenorm_residual=False, @@ -129,8 +130,7 @@ def __init__( self.num_hidden_layers = num_hidden_layers self.layer_norm_epsilon = layer_norm_epsilon self.conv_kernel = conv_kernel - self.expand = expand - self.intermediate_size = int(self.expand * self.hidden_size) + self.intermediate_size = int(expand * self.hidden_size) self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id self.pad_token_id = pad_token_id diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 097338b7041414..ec72c9969f6284 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -14,8 +14,8 @@ # limitations under the License. """PyTorch MAMBA model.""" -from dataclasses import dataclass import math +from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -133,7 +133,9 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=Non # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) # TODO find a better name for this one - time_step, B, C = torch.split(ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) @@ -214,7 +216,9 @@ def slow_forward(self, hidden_states, inference_params=None): # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) - time_step, B, C = torch.split(ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) discrete_time_step = self.dt_proj(time_step) A = -torch.exp(self.A_log.float()) # (intermediate_size, ssm_state_size) @@ -319,7 +323,7 @@ def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, MambaMixer): module.A_log._no_weight_decay = True - module.D._no_weight_decay = True + module.D._no_weight_decay = True dt_init_std = self.config.time_step_rank**-0.5 * self.config.time_step_scale if self.config.time_step_init_scheme == "constant": @@ -328,7 +332,8 @@ def _init_weights(self, module): nn.init.uniform_(module.dt_proj.weight, -dt_init_std, dt_init_std) dt = torch.exp( - torch.rand(self.config.intermediate_size) * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + torch.rand(self.config.intermediate_size) + * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min)) + math.log(self.config.time_step_min) ).clamp(min=self.config.time_step_floor) # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 @@ -362,8 +367,6 @@ def _init_weights(self, module): p /= math.sqrt(self.config.num_layers) - - @dataclass class MambaOutput(ModelOutput): """ From cf5852941dd9335838680454730b635a7caadaa0 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Mar 2024 10:43:15 +0100 Subject: [PATCH 085/116] fix the doc --- src/transformers/models/mamba/modeling_mamba.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index ec72c9969f6284..8ea13a4f9bc833 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -437,11 +437,9 @@ class MambaCausalLMOutput(ModelOutput): MAMBA_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): - `input_ids_length` = `sequence_length` if `past_key_values` is `None` else - `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input - sequence tokens in the vocabulary. + Indices of input sequence tokens in the vocabulary. - If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + If `inference_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as `input_ids`. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -452,11 +450,11 @@ class MambaCausalLMOutput(ModelOutput): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - inference_params (tuple of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`, *optional*): + inference_params (`MambaCache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). use_cache (`bool`, *optional*): - If set to `True`, the last state is returned and can be used to quickly generate the next logits. + If set to `True`, the `inference_params` is returned and can be used to quickly generate the next logits. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. From e9c34472a35d1204e68958f752cfbc9d8f9bdb00 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Mar 2024 10:48:36 +0100 Subject: [PATCH 086/116] fixup --- src/transformers/models/mamba/configuration_mamba.py | 1 + src/transformers/models/mamba/modeling_mamba.py | 9 ++------- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 1e2c41d3380a6c..6228a165290331 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -130,6 +130,7 @@ def __init__( self.num_hidden_layers = num_hidden_layers self.layer_norm_epsilon = layer_norm_epsilon self.conv_kernel = conv_kernel + self.expand = expand self.intermediate_size = int(expand * self.hidden_size) self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 8ea13a4f9bc833..35922182cbc680 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -566,7 +566,7 @@ def forward( MAMBA_START_DOCSTRING, ) class MambaForCausalLM(MambaPreTrainedModel): - _tied_weights_keys = ["head.weight"] + _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) @@ -588,12 +588,7 @@ def set_input_embeddings(self, new_embeddings): return self.backbone.set_input_embeddings(new_embeddings) def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, - **kwargs, + self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs ) -> Dict[str, Any]: model_kwargs["inference_params"] = outputs["inference_params"] return model_kwargs From 1282a75fdf2e440d152f8ed0c98bb23610d800ab Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Mar 2024 12:21:18 +0100 Subject: [PATCH 087/116] tentative update but slow is no longer good --- .../models/mamba/configuration_mamba.py | 2 +- .../models/mamba/modeling_mamba.py | 105 +++++++++--------- tests/models/mamba/test_modeling_mamba.py | 2 +- 3 files changed, 52 insertions(+), 57 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 6228a165290331..6e0a9b442d1e7f 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -113,7 +113,7 @@ def __init__( use_conv_bias=True, hidden_act="silu", initializer_range=0.1, - residual_in_fp32=False, + residual_in_fp32=True, time_step_rank="auto", time_step_scale=1.0, time_step_min=0.001, diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 35922182cbc680..42666e7f042d35 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -73,6 +73,16 @@ class MambaMixer(nn.Module): + """ + Selective layer TODO DOC DOC DOC + Compute ∆ A B C D, the state space parameters. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) + + ∆, B and C are the `selective` parameters + """ + def __init__(self, config, layer_idx): super().__init__() self.hidden_size = config.hidden_size @@ -102,14 +112,12 @@ def __init__(self, config, layer_idx): self.dt_proj = nn.Linear(self.time_step_rank, self.intermediate_size, bias=True) # S4D real initialization. These are not discretized! - # THe core is to load them, compute the discrete states, then write the updates state. Keeps the memory bounded + # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :] A = A.expand(self.intermediate_size, -1).contiguous() - # TODO this parameter should be kept in float32. We don't have support for in _keep_in_float32 self.A_log = nn.Parameter(torch.log(A)) - # D "skip" parameter - self.D = nn.Parameter(torch.ones(self.intermediate_size)) # Keep in fp32 + self.D = nn.Parameter(torch.ones(self.intermediate_size)) self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=None): @@ -132,20 +140,18 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=Non # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) # TODO find a better name for this one + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) time_step, B, C = torch.split( ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 ) discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) - # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) A = -torch.exp(self.A_log.float()) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) ssm_state = inference_params.ssm_states[self.layer_idx] time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None if inference_params is not None and inference_params.seqlen_offset > 0: - y = selective_state_update( + scan_outputs = selective_state_update( ssm_state, hidden_states[..., 0], discrete_time_step[..., 0], @@ -158,7 +164,7 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=Non dt_softplus=True, ).unsqueeze(-1) else: - y, last_state = selective_scan_fn( + scan_outputs, last_state = selective_scan_fn( hidden_states, discrete_time_step, A, @@ -174,82 +180,71 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=Non inference_params.ssm_states[self.layer_idx].copy_(ssm_state) # 4. Final linear projection - selected_states = self.out_proj(y.transpose(1, 2)) + selected_states = self.out_proj(scan_outputs.transpose(1, 2)) return selected_states def slow_forward(self, hidden_states, inference_params=None): - """ - - Compute ∆ A B C D, the state space parameters. - A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) - ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, - and is why Mamba is called **selective** state spaces) - - Args: - hidden_states: - inference_params: - - Returns: - - """ - batch_size, seq_len, _ = hidden_states.shape + _, seq_len, _ = hidden_states.shape # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states).transpose(1, 2) + projected_states = self.in_proj(hidden_states).transpose(1, 2) # (batch, 2 * intermediate_size, seq_len) hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation if inference_params.seqlen_offset > 0: - conv_state = inference_params.conv_states[self.layer_idx] + conv_state = inference_params.conv_states[self.layer_idx] # (batch, intermediate_size, conv_kernel_size) conv_state = torch.roll(conv_state, shifts=-1, dims=-1) conv_state[:, :, -1] = hidden_states[:, :, 0] - - hidden_states = self.act(torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + self.conv1d.bias) - hidden_states = hidden_states.unsqueeze(-1) - + bias = getattr(self.conv1d, "bias", 0.0) + hidden_states = self.act(torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + bias) + hidden_states = hidden_states.unsqueeze(-1) # (batch, intermediate_size, 1) else: inference_params.conv_states[self.layer_idx] = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) - hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # (batch, intermediate_size, seq_len) # 3. State Space Model sequence transformation - # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + # 3.a. Selection: (batch, seq_len, self.time_step_rank + self.ssm_state_size * 2) + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) time_step, B, C = torch.split( ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 ) - discrete_time_step = self.dt_proj(time_step) + discrete_time_step = self.dt_proj(time_step) # (batch, seq_len, intermediate_size) + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) + + # 3.b. Discretization: B and C to [batch_size, seq_len, intermediate_size, ssm_state_size] (SRAM) A = -torch.exp(self.A_log.float()) # (intermediate_size, ssm_state_size) + # [batch_size, intermediate_size, seq_len, 1] X [1, intermediate_size, 1, ssm_state_size] + discrete_A = torch.exp(discrete_time_step[:, :, :, None] * A[None, :, None, :]) + # [batch_size, intermediate_size, seq_len, 1] X [batch_size, 1, seq_len, ssm_state_size] + discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() + # [batch_size, intermediade_size, seq_len, 1] X [batch_size, seq_len, ssm_state_size, 1] + deltaB_u = discrete_B * hidden_states[:, :, :, None] - # 3.b. discretize time_step, B and C: zero-order hold from (B,L,D) to (B,L,D,N) - discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) - # [batch_size, d, l, 1] X [1, d, 1, n] -> [batch_size, d, l, n] - dA = torch.exp(discrete_time_step[:, :, :, None] * A[None, :, None, :]) - # [batch_size, d, l, 1] [b, d, l, 1] -> [batch_size, d, l, 1] X [batch_size, 1, l, n] -> [batch_size, d, l, n] - deltaB_u = (discrete_time_step[:, :, :, None] * hidden_states[:, :, :, None]) * B[:, None, :, :] + deltaB_u = (discrete_time_step[:, :, :, None] * hidden_states[:, :, :, None]) * B[:, None, :, :].float() # 3.c perform the recurrence y ← SSM(A, B, C)(x) ssm_state = inference_params.ssm_states[self.layer_idx] - ys = [] + scan_outputs = [] for i in range(seq_len): - ssm_state = ssm_state * dA[:, :, i, :] + deltaB_u[:, :, i, :] - # [b, d, n] X [b, n] -> [b, d] - y = torch.matmul(ssm_state, C[:, i, :].unsqueeze(-1)) - ys.append(y[:, :, 0]) - y = torch.stack(ys, dim=-1) # shape (b, l, d) - y = y + (hidden_states * self.D[None, :, None]) - y = y * self.act(gate) # (B D) + ssm_state = ssm_state * discrete_A[:, :, i, :] + deltaB_u[:, :, i, :] + # [batch_size, intermediade_size, ssm_state] X [batch_size, ssm_state] -> [batch_size, intermediade_size] + scan_output = torch.matmul(ssm_state, C[:, i, :].unsqueeze(-1).float()) + scan_outputs.append(scan_output[:, :, 0]) + scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size] + scan_output = scan_output + (hidden_states * self.D[None, :, None].float()) + contextualized_states = (scan_output * self.act(gate)).to(hidden_states.dtype) # [batch, intermediade_size, seq_len] # 4. Final linear projection - selected_states = self.out_proj(y.transpose(1, 2)) + contextualized_states = self.out_proj(contextualized_states.transpose(1, 2)) # [batch, seq_len, hidden_size] inference_params.ssm_states[self.layer_idx] = ssm_state - return selected_states + return contextualized_states def forward(self, hidden_states, inference_params=None): - if is_fast_path_available and "cuda" in self.x_proj.weight.device.type: - return self.cuda_kernels_forward(hidden_states, inference_params) + # if is_fast_path_available and "cuda" in self.x_proj.weight.device.type: + # return self.cuda_kernels_forward(hidden_states, inference_params) return self.slow_forward(hidden_states, inference_params) @@ -300,7 +295,7 @@ def __init__(self, config, layer_idx): def forward(self, hidden_states, inference_params=None): residual = hidden_states hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: + if self.residual_in_fp32 or True: residual = residual.to(torch.float32) hidden_states = self.mixer(hidden_states, inference_params=inference_params) @@ -641,7 +636,7 @@ def forward( ) hidden_states = mamba_outputs[0] - logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)) + logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float() loss = None if labels is not None: diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 30bb7efad68109..3046670dc3cbd8 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -371,7 +371,7 @@ def test_simple_generate(self): -53.4688, -57.3750, -56.9375, -55.7500, -53.3125, -55.8438, -57.0000, -56.9062, -56.2188, -54.7188, -56.4375, -57.5000 ] - ,dtype=torch.float16) # fmt: skip + ,dtype=torch.float32) # fmt: skip torch.testing.assert_close(logits[0, 0, :40].cpu(), EXPECTED_LOGITS_NO_GRAD) From fa561b26840fc2973855a0774e757de3bfa5820d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Mar 2024 13:18:20 +0100 Subject: [PATCH 088/116] nit --- src/transformers/models/mamba/modeling_mamba.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 42666e7f042d35..146c4cf19e01c7 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -195,6 +195,8 @@ def slow_forward(self, hidden_states, inference_params=None): conv_state = inference_params.conv_states[self.layer_idx] # (batch, intermediate_size, conv_kernel_size) conv_state = torch.roll(conv_state, shifts=-1, dims=-1) conv_state[:, :, -1] = hidden_states[:, :, 0] + inference_params.conv_states[self.layer_idx] = conv_state + bias = getattr(self.conv1d, "bias", 0.0) hidden_states = self.act(torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + bias) hidden_states = hidden_states.unsqueeze(-1) # (batch, intermediate_size, 1) From 91b81061165a5fbee2c9a9eb4439e7a5d06f0f21 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Mar 2024 13:19:37 +0100 Subject: [PATCH 089/116] should we always use float32? --- src/transformers/models/mamba/modeling_mamba.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 146c4cf19e01c7..06141d4fb9d311 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -208,7 +208,7 @@ def slow_forward(self, hidden_states, inference_params=None): # 3. State Space Model sequence transformation # 3.a. Selection: (batch, seq_len, self.time_step_rank + self.ssm_state_size * 2) - ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) time_step, B, C = torch.split( ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 ) @@ -236,7 +236,7 @@ def slow_forward(self, hidden_states, inference_params=None): scan_outputs.append(scan_output[:, :, 0]) scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size] scan_output = scan_output + (hidden_states * self.D[None, :, None].float()) - contextualized_states = (scan_output * self.act(gate)).to(hidden_states.dtype) # [batch, intermediade_size, seq_len] + contextualized_states = (scan_output * self.act(gate)).to(hidden_states.dtype) # 4. Final linear projection contextualized_states = self.out_proj(contextualized_states.transpose(1, 2)) # [batch, seq_len, hidden_size] From e8142caf5af8ef75b36de4c2cf655cb7502fc7fd Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Mar 2024 13:20:29 +0100 Subject: [PATCH 090/116] nits --- src/transformers/models/mamba/modeling_mamba.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 06141d4fb9d311..14b2e793ca7f21 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -180,8 +180,8 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=Non inference_params.ssm_states[self.layer_idx].copy_(ssm_state) # 4. Final linear projection - selected_states = self.out_proj(scan_outputs.transpose(1, 2)) - return selected_states + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + return contextualized_states def slow_forward(self, hidden_states, inference_params=None): _, seq_len, _ = hidden_states.shape @@ -224,8 +224,6 @@ def slow_forward(self, hidden_states, inference_params=None): # [batch_size, intermediade_size, seq_len, 1] X [batch_size, seq_len, ssm_state_size, 1] deltaB_u = discrete_B * hidden_states[:, :, :, None] - deltaB_u = (discrete_time_step[:, :, :, None] * hidden_states[:, :, :, None]) * B[:, None, :, :].float() - # 3.c perform the recurrence y ← SSM(A, B, C)(x) ssm_state = inference_params.ssm_states[self.layer_idx] scan_outputs = [] @@ -236,9 +234,9 @@ def slow_forward(self, hidden_states, inference_params=None): scan_outputs.append(scan_output[:, :, 0]) scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size] scan_output = scan_output + (hidden_states * self.D[None, :, None].float()) - contextualized_states = (scan_output * self.act(gate)).to(hidden_states.dtype) + scan_output = (scan_output * self.act(gate)).to(hidden_states.dtype) # 4. Final linear projection - contextualized_states = self.out_proj(contextualized_states.transpose(1, 2)) # [batch, seq_len, hidden_size] + contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] inference_params.ssm_states[self.layer_idx] = ssm_state From 623b636123218d3f190b135f210002eddd73be44 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Mar 2024 13:21:45 +0100 Subject: [PATCH 091/116] revert wrong changes --- src/transformers/models/mamba/modeling_mamba.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 14b2e793ca7f21..adfb4195575825 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -232,19 +232,18 @@ def slow_forward(self, hidden_states, inference_params=None): # [batch_size, intermediade_size, ssm_state] X [batch_size, ssm_state] -> [batch_size, intermediade_size] scan_output = torch.matmul(ssm_state, C[:, i, :].unsqueeze(-1).float()) scan_outputs.append(scan_output[:, :, 0]) + inference_params.ssm_states[self.layer_idx] = ssm_state scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size] scan_output = scan_output + (hidden_states * self.D[None, :, None].float()) scan_output = (scan_output * self.act(gate)).to(hidden_states.dtype) # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] - inference_params.ssm_states[self.layer_idx] = ssm_state - return contextualized_states def forward(self, hidden_states, inference_params=None): - # if is_fast_path_available and "cuda" in self.x_proj.weight.device.type: - # return self.cuda_kernels_forward(hidden_states, inference_params) + if is_fast_path_available and "cuda" in self.x_proj.weight.device.type: + return self.cuda_kernels_forward(hidden_states, inference_params) return self.slow_forward(hidden_states, inference_params) From 566c799c28bf66eb01e5df5cb6258d2303e43fdd Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 1 Mar 2024 13:22:08 +0100 Subject: [PATCH 092/116] res in float32 --- src/transformers/models/mamba/modeling_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index adfb4195575825..ba89422716dbea 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -294,7 +294,7 @@ def __init__(self, config, layer_idx): def forward(self, hidden_states, inference_params=None): residual = hidden_states hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32 or True: + if self.residual_in_fp32: residual = residual.to(torch.float32) hidden_states = self.mixer(hidden_states, inference_params=inference_params) From 5d637d9fabee7f998ac74bfffc03c499adeda2bc Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 2 Mar 2024 04:03:08 +0100 Subject: [PATCH 093/116] cleanup --- .../models/mamba/modeling_mamba.py | 182 ++++++++++-------- 1 file changed, 97 insertions(+), 85 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index ba89422716dbea..4358460dccbdc1 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -39,13 +39,13 @@ logger = logging.get_logger(__name__) if is_mamba_ssm_available(): - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn + from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update else: logger.warning_once( "The `mamba_ssm` package is not installed in your environnement. Make sure to install it if you want to use the custom cuda kernels" ) - selective_state_update, selective_scan_fn = None, None + selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update @@ -123,122 +123,134 @@ def __init__(self, config, layer_idx): def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=None): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) - hidden_states, gate = projected_states.chunk(2, dim=1) - - # 2. Convolution sequence transformation - conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if inference_params is not None and inference_params.seqlen_offset > 0: - conv_state = inference_params.conv_states[self.layer_idx] - hidden_states = causal_conv1d_update( - hidden_states.squeeze(-1), conv_state, conv_weights, self.conv1d.bias, self.activation - ) - hidden_states = hidden_states.unsqueeze(-1) - else: - conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) - inference_params.conv_states[self.layer_idx].copy_(conv_state) - hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation) - - # 3. State Space Model sequence transformation - # 3.a. input varying initialization of time_step, B and C - ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) - time_step, B, C = torch.split( - ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 - ) - discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) - A = -torch.exp(self.A_log.float()) - # 3.c perform the recurrence y ← SSM(A, B, C)(x) - ssm_state = inference_params.ssm_states[self.layer_idx] - time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None - if inference_params is not None and inference_params.seqlen_offset > 0: - scan_outputs = selective_state_update( - ssm_state, - hidden_states[..., 0], - discrete_time_step[..., 0], - A, - B[:, 0], - C[:, 0], - self.D, - gate[..., 0], - time_proj_bias, - dt_softplus=True, - ).unsqueeze(-1) - else: - scan_outputs, last_state = selective_scan_fn( - hidden_states, - discrete_time_step, + if self.training and inference_params is None: # Doesn't support outputting the states -> used for training + contextualized_states = mamba_inner_fn( + projected_states, + self.conv1d.weight, + self.conv1d.bias, + self.x_proj.weight, + self.dt_proj.weight, + self.out_proj.weight, + self.out_proj.bias, A, - B.transpose(1, 2), - C.transpose(1, 2), + None, # input-dependent B + None, # input-dependent C self.D.float(), - gate, - time_proj_bias, + delta_bias=self.dt_proj.bias.float(), delta_softplus=True, - return_last_state=True, ) - if last_state is not None: - inference_params.ssm_states[self.layer_idx].copy_(ssm_state) + + else: + hidden_states, gate = projected_states.chunk(2, dim=1) + + # 2. Convolution sequence transformation + conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) + if inference_params is not None and inference_params.seqlen_offset > 0: + hidden_states = causal_conv1d_update( + hidden_states.squeeze(-1), inference_params.conv_states[self.layer_idx], conv_weights, self.conv1d.bias, self.activation + ) + hidden_states = hidden_states.unsqueeze(-1) + else: + inference_params.conv_states[self.layer_idx] = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) + hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation) + + # 3. State Space Model sequence transformation + # 3.a. input varying initialization of time_step, B and C + ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) + time_step, B, C = torch.split( + ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 + ) + discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2) + + A = -torch.exp(self.A_log.float()) + # 3.c perform the recurrence y ← SSM(A, B, C)(x) + time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None + if inference_params is not None and inference_params.seqlen_offset > 0: + scan_outputs = selective_state_update( + inference_params.ssm_states[self.layer_idx], + hidden_states[..., 0], + discrete_time_step[..., 0], + A, + B[:, 0], + C[:, 0], + self.D, + gate[..., 0], + time_proj_bias, + dt_softplus=True, + ).unsqueeze(-1) + else: + scan_outputs, ssm_state = selective_scan_fn( + hidden_states, + discrete_time_step, + A, + B.transpose(1, 2), + C.transpose(1, 2), + self.D.float(), + gate, + time_proj_bias, + delta_softplus=True, + return_last_state=True, + ) + if ssm_state is not None: + inference_params.ssm_states[self.layer_idx] = ssm_state # 4. Final linear projection contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) return contextualized_states - def slow_forward(self, hidden_states, inference_params=None): - _, seq_len, _ = hidden_states.shape - + def slow_forward(self, input_states, inference_params=None): + _, seq_len, _ = input_states.shape + dtype = input_states.dtype # 1. Gated MLP's linear projection - projected_states = self.in_proj(hidden_states).transpose(1, 2) # (batch, 2 * intermediate_size, seq_len) + projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len] hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation if inference_params.seqlen_offset > 0: - conv_state = inference_params.conv_states[self.layer_idx] # (batch, intermediate_size, conv_kernel_size) + conv_state = inference_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] conv_state = torch.roll(conv_state, shifts=-1, dims=-1) conv_state[:, :, -1] = hidden_states[:, :, 0] - inference_params.conv_states[self.layer_idx] = conv_state + inference_params.conv_states[self.layer_idx].copy_(conv_state) bias = getattr(self.conv1d, "bias", 0.0) - hidden_states = self.act(torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + bias) - hidden_states = hidden_states.unsqueeze(-1) # (batch, intermediate_size, 1) + hidden_states = self.act(torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + bias).to(dtype) + hidden_states = hidden_states.unsqueeze(-1) # (batch, intermediate_size, 1) : decoding else: inference_params.conv_states[self.layer_idx] = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) - ) - hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # (batch, intermediate_size, seq_len) + ).to(inference_params.dtype) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # (batch, intermediate_size, seq_len) # 3. State Space Model sequence transformation - # 3.a. Selection: (batch, seq_len, self.time_step_rank + self.ssm_state_size * 2) + # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] ssm_parameters = self.x_proj(hidden_states.transpose(1, 2)) time_step, B, C = torch.split( ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1 ) - discrete_time_step = self.dt_proj(time_step) # (batch, seq_len, intermediate_size) - discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) - - # 3.b. Discretization: B and C to [batch_size, seq_len, intermediate_size, ssm_state_size] (SRAM) - A = -torch.exp(self.A_log.float()) # (intermediate_size, ssm_state_size) - # [batch_size, intermediate_size, seq_len, 1] X [1, intermediate_size, 1, ssm_state_size] - discrete_A = torch.exp(discrete_time_step[:, :, :, None] * A[None, :, None, :]) - # [batch_size, intermediate_size, seq_len, 1] X [batch_size, 1, seq_len, ssm_state_size] - discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() - # [batch_size, intermediade_size, seq_len, 1] X [batch_size, seq_len, ssm_state_size, 1] - deltaB_u = discrete_B * hidden_states[:, :, :, None] + discrete_time_step = self.dt_proj(time_step) # [batch, seq_len, intermediate_size] + discrete_time_step = nn.functional.softplus(discrete_time_step).transpose(1, 2) # [batch, intermediate_size, seq_len] + + # 3.b. Discretization: B and C to [batch, seq_len, intermediate_size, ssm_state_size] (SRAM) + A = -torch.exp(self.A_log.float()) # [intermediate_size, ssm_state_size] + discrete_A = torch.exp(A[None, :, None, :] * discrete_time_step[:, :, :, None]) # [batch, intermediate_size, seq_len, ssm_state_size] + discrete_B = discrete_time_step[:, :, :, None] * B[:, None, :, :].float() # [batch, intermediade_size, seq_len, ssm_state_size] + deltaB_u = discrete_B * hidden_states[:, :, :, None].float() # 3.c perform the recurrence y ← SSM(A, B, C)(x) ssm_state = inference_params.ssm_states[self.layer_idx] scan_outputs = [] for i in range(seq_len): - ssm_state = ssm_state * discrete_A[:, :, i, :] + deltaB_u[:, :, i, :] - # [batch_size, intermediade_size, ssm_state] X [batch_size, ssm_state] -> [batch_size, intermediade_size] - scan_output = torch.matmul(ssm_state, C[:, i, :].unsqueeze(-1).float()) + ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state] + scan_output = torch.matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1)) # [batch, intermediade_size, 1] scan_outputs.append(scan_output[:, :, 0]) - inference_params.ssm_states[self.layer_idx] = ssm_state - scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size] - scan_output = scan_output + (hidden_states * self.D[None, :, None].float()) - scan_output = (scan_output * self.act(gate)).to(hidden_states.dtype) - # 4. Final linear projection - contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] + scan_output = torch.stack(scan_outputs, dim=-1) # [batch, seq_len, intermediade_size] + scan_output = scan_output + (hidden_states * self.D[None, :, None]) + scan_output = (scan_output * self.act(gate)) + inference_params.ssm_states[self.layer_idx].copy_(ssm_state) + # 4. Final linear projection + contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] return contextualized_states def forward(self, hidden_states, inference_params=None): @@ -250,7 +262,7 @@ def forward(self, hidden_states, inference_params=None): class MambaCache: def __init__(self, config, batch_size, dtype=torch.float16, device=None): self.seqlen_offset = 0 - + self.dtype = dtype intermediate_size = config.intermediate_size ssm_state_size = config.state_size conv_kernel_size = config.conv_kernel @@ -294,7 +306,7 @@ def __init__(self, config, layer_idx): def forward(self, hidden_states, inference_params=None): residual = hidden_states hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32: + if self.residual_in_fp32 or True: residual = residual.to(torch.float32) hidden_states = self.mixer(hidden_states, inference_params=inference_params) From 648a29225e0495aa426383c1b68b34e45794361d Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 2 Mar 2024 04:10:50 +0100 Subject: [PATCH 094/116] skip fmt for now --- .../models/mamba/modeling_mamba.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 4358460dccbdc1..a51e7721cfed0e 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -39,7 +39,7 @@ logger = logging.get_logger(__name__) if is_mamba_ssm_available(): - from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn + from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update else: logger.warning_once( @@ -133,7 +133,7 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=Non self.dt_proj.weight, self.out_proj.weight, self.out_proj.bias, - A, + -torch.exp(self.A_log.float()), None, # input-dependent B None, # input-dependent C self.D.float(), @@ -148,12 +148,20 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=Non conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) if inference_params is not None and inference_params.seqlen_offset > 0: hidden_states = causal_conv1d_update( - hidden_states.squeeze(-1), inference_params.conv_states[self.layer_idx], conv_weights, self.conv1d.bias, self.activation + hidden_states.squeeze(-1), + inference_params.conv_states[self.layer_idx], + conv_weights, + self.conv1d.bias, + self.activation, ) hidden_states = hidden_states.unsqueeze(-1) else: - inference_params.conv_states[self.layer_idx] = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) - hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation) + inference_params.conv_states[self.layer_idx] = nn.functional.pad( + hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + hidden_states = causal_conv1d_fn( + hidden_states, conv_weights, self.conv1d.bias, activation=self.activation + ) # 3. State Space Model sequence transformation # 3.a. input varying initialization of time_step, B and C @@ -199,6 +207,7 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=Non contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) return contextualized_states + # fmt: off def slow_forward(self, input_states, inference_params=None): _, seq_len, _ = input_states.shape dtype = input_states.dtype @@ -215,12 +224,12 @@ def slow_forward(self, input_states, inference_params=None): bias = getattr(self.conv1d, "bias", 0.0) hidden_states = self.act(torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + bias).to(dtype) - hidden_states = hidden_states.unsqueeze(-1) # (batch, intermediate_size, 1) : decoding + hidden_states = hidden_states.unsqueeze(-1) # [batch, intermediate_size, 1] : decoding else: inference_params.conv_states[self.layer_idx] = nn.functional.pad( hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ).to(inference_params.dtype) - hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # (batch, intermediate_size, seq_len) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] # 3. State Space Model sequence transformation # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2] @@ -252,6 +261,7 @@ def slow_forward(self, input_states, inference_params=None): # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] return contextualized_states + # fmt: on def forward(self, hidden_states, inference_params=None): if is_fast_path_available and "cuda" in self.x_proj.weight.device.type: From e306e891dae4e296df2d8522154e70c8cf5d14d9 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 2 Mar 2024 04:36:38 +0100 Subject: [PATCH 095/116] update generation values --- tests/models/mamba/test_modeling_mamba.py | 59 ++++++++++++----------- 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 3046670dc3cbd8..491c11304ebb97 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -14,11 +14,12 @@ # limitations under the License. +from parameterized import parameterized import unittest from unittest.util import safe_repr from transformers import AutoTokenizer, MambaConfig, is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import require_torch, slow, torch_device, require_torch_multi_gpu from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -183,14 +184,15 @@ def create_and_check_state_equivalency(self, config, input_ids, *args): outputs = model(input_ids) output_whole = outputs.last_hidden_state - outputs = model(input_ids[:, :2], use_cache=True) + outputs = model(input_ids[:, :-1], use_cache=True) output_one = outputs.last_hidden_state # Using the state computed on the first inputs, we will get the same output - outputs = model(input_ids[:, 2:], inference_params=outputs.inference_params) + outputs = model(input_ids[:, -1:], inference_params=outputs.inference_params) output_two = outputs.last_hidden_state self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5)) + # TODO the orignal mamba does not support decoding more than 1 token neither do we def create_and_check_forward_and_backwards(self, config, input_ids, *args, gradient_checkpointing=False): model = MambaForCausalLM(config) @@ -268,7 +270,7 @@ def test_config(self): def test_retain_grad_hidden_states_attentions(self): pass - # @require_torch_multi_gpu + @require_torch_multi_gpu def test_multi_gpu_data_parallel_forward(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -346,18 +348,15 @@ def setUp(self): self.model_id = "ArthurZ/mamba-2.8b" self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) - def test_simple_generate(self): - import torch - - from transformers import AutoTokenizer, MambaForCausalLM - + @parameterized.expand([(torch_device,), ("cpu",)]) + def test_simple_generate(self, device): tokenizer = AutoTokenizer.from_pretrained("ArthurZ/mamba-130m") tokenizer.pad_token = tokenizer.eos_token model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", torch_dtype=torch.float16) - model.to(torch_device) + model.to(device) model.config.use_cache = True - input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(torch_device) + input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device) with torch.no_grad(): logits = model(input_ids=input_ids).logits @@ -377,49 +376,53 @@ def test_simple_generate(self): out = model.generate(input_ids, do_sample=False, max_new_tokens=10) output_sentence = tokenizer.decode(out[0, :]) - self.assertEqual(output_sentence, "Hey how are you doing?\n\nA:\n\nI have a similar") + self.assertEqual(output_sentence, "Hey how are you doing?\n\nI'm so glad you're here.") - def test_simple_generate_cuda_kernels_tiny(self): - expected_output = "Hello my name is John of the Golden, and I am the Lord" + @parameterized.expand([(torch_device,), ("cpu",)]) + def test_simple_generate_cuda_kernels_tiny(self, device): + expected_output = "Hello my name is John and I am a newbie to the world" - input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) - model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", torch_dtype=torch.float16).to(torch_device) + input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device) + model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-130m", torch_dtype=torch.float16).to(device) output = model.generate(input_ids, max_new_tokens=10) output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) + @parameterized.expand([(torch_device,), ("cpu",)]) @slow - def test_simple_generate_cuda_kernels_small(self): - expected_output = "Hello my name is\n\nI am a student of the art of" + def test_simple_generate_cuda_kernels_small(self, device): + expected_output = "Hello my name is\n\nI am a\n\nI am a" - input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) - model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-790m", torch_dtype=torch.float16).to(torch_device) + input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device) + model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-790m", torch_dtype=torch.float16).to(device) output = model.generate(input_ids, max_new_tokens=10) output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) + @parameterized.expand([(torch_device,), ("cpu",)]) @slow - def test_simple_generate_cuda_kernels_mid(self): - expected_output = "Hello my name is John and I am a software engineer. I have" + def test_simple_generate_cuda_kernels_mid(self, device): + expected_output = "Hello my name is John and I am a\n\nI am a" - input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) - model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-1.4b", torch_dtype=torch.float16).to(torch_device) + input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device) + model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-1.4b", torch_dtype=torch.float16).to(device) output = model.generate(input_ids, max_new_tokens=10) output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) + @parameterized.expand([(torch_device,), ("cpu",)]) @slow - def test_simple_generate_cuda_kernels_big(self): - expected_output = "Hello my name is John. I am a student at the University of" + def test_simple_generate_cuda_kernels_big(self, device): + expected_output = "Hello my name is John and I am a new member of this forum" - input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(torch_device) - model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-2.8b", torch_dtype=torch.float16).to(torch_device) + input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device) + model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-2.8b", torch_dtype=torch.float16).to(device) output = model.generate(input_ids, max_new_tokens=10) output_sentence = self.tokenizer.decode(output[0].tolist()) From 057d7a3db117d5f8700c36af1dbd1347e0cce4ce Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 2 Mar 2024 04:44:32 +0100 Subject: [PATCH 096/116] update test values running original model --- tests/models/mamba/test_modeling_mamba.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 491c11304ebb97..64851222754d47 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -358,6 +358,10 @@ def test_simple_generate(self, device): model.config.use_cache = True input_ids = tokenizer("Hey how are you doing?", return_tensors="pt")["input_ids"].to(device) + out = model.generate(input_ids, do_sample=False, max_new_tokens=10) + output_sentence = tokenizer.decode(out[0, :]) + self.assertEqual(output_sentence, "Hey how are you doing?\n\nI'm so glad you're here.") + with torch.no_grad(): logits = model(input_ids=input_ids).logits @@ -372,11 +376,8 @@ def test_simple_generate(self, device): ] ,dtype=torch.float32) # fmt: skip - torch.testing.assert_close(logits[0, 0, :40].cpu(), EXPECTED_LOGITS_NO_GRAD) + torch.testing.assert_close(logits[0, 0, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1e-3) - out = model.generate(input_ids, do_sample=False, max_new_tokens=10) - output_sentence = tokenizer.decode(out[0, :]) - self.assertEqual(output_sentence, "Hey how are you doing?\n\nI'm so glad you're here.") @parameterized.expand([(torch_device,), ("cpu",)]) def test_simple_generate_cuda_kernels_tiny(self, device): @@ -406,12 +407,12 @@ def test_simple_generate_cuda_kernels_small(self, device): @parameterized.expand([(torch_device,), ("cpu",)]) @slow def test_simple_generate_cuda_kernels_mid(self, device): - expected_output = "Hello my name is John and I am a\n\nI am a" + expected_output = "Hello my name is John and I am a\n\nI am a single father of a beautiful daughter. I am a" input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device) model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-1.4b", torch_dtype=torch.float16).to(device) - output = model.generate(input_ids, max_new_tokens=10) + output = model.generate(input_ids, max_new_tokens=20) output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) @@ -419,12 +420,12 @@ def test_simple_generate_cuda_kernels_mid(self, device): @parameterized.expand([(torch_device,), ("cpu",)]) @slow def test_simple_generate_cuda_kernels_big(self, device): - expected_output = "Hello my name is John and I am a new member of this forum" + expected_output = "Hello my name is John and I am a new member of this forum. I am a retired Marine and I am a member of the Marine Corps League. I am a" input_ids = self.tokenizer("Hello my name is", return_tensors="pt").input_ids.to(device) model = MambaForCausalLM.from_pretrained("ArthurZ/mamba-2.8b", torch_dtype=torch.float16).to(device) - output = model.generate(input_ids, max_new_tokens=10) + output = model.generate(input_ids, max_new_tokens=30) output_sentence = self.tokenizer.decode(output[0].tolist()) self.assertEqual(output_sentence, expected_output) From 72f8936c74b8bb7f5e7dca000882f273e7f9e8f5 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Sat, 2 Mar 2024 04:50:14 +0100 Subject: [PATCH 097/116] fixup --- tests/models/mamba/test_modeling_mamba.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 64851222754d47..928621fdd873ce 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -14,12 +14,13 @@ # limitations under the License. -from parameterized import parameterized import unittest from unittest.util import safe_repr +from parameterized import parameterized + from transformers import AutoTokenizer, MambaConfig, is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device, require_torch_multi_gpu +from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -378,7 +379,6 @@ def test_simple_generate(self, device): torch.testing.assert_close(logits[0, 0, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1e-3) - @parameterized.expand([(torch_device,), ("cpu",)]) def test_simple_generate_cuda_kernels_tiny(self, device): expected_output = "Hello my name is John and I am a newbie to the world" From f415081d05a061719d61431efe8731c7b5cf5265 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Mar 2024 01:54:25 +0100 Subject: [PATCH 098/116] update tests + rename inference_params to cache_params + make sure training does not use cache_params --- .../models/mamba/modeling_mamba.py | 132 +++++++++--------- tests/models/mamba/test_modeling_mamba.py | 28 ++-- 2 files changed, 83 insertions(+), 77 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index a51e7721cfed0e..999d0429e0815e 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -120,11 +120,11 @@ def __init__(self, config, layer_idx): self.D = nn.Parameter(torch.ones(self.intermediate_size)) self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) - def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=None): + def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params=None): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) - if self.training and inference_params is None: # Doesn't support outputting the states -> used for training + if self.training and cache_params is None: # Doesn't support outputting the states -> used for training contextualized_states = mamba_inner_fn( projected_states, self.conv1d.weight, @@ -146,19 +146,21 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=Non # 2. Convolution sequence transformation conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)) - if inference_params is not None and inference_params.seqlen_offset > 0: + if cache_params is not None and cache_params.seqlen_offset > 0: hidden_states = causal_conv1d_update( hidden_states.squeeze(-1), - inference_params.conv_states[self.layer_idx], + cache_params.conv_states[self.layer_idx], conv_weights, self.conv1d.bias, self.activation, ) hidden_states = hidden_states.unsqueeze(-1) else: - inference_params.conv_states[self.layer_idx] = nn.functional.pad( - hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) - ) + if cache_params is not None: + conv_states = nn.functional.pad( + hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) + cache_params.conv_states[self.layer_idx].copy_(conv_states) hidden_states = causal_conv1d_fn( hidden_states, conv_weights, self.conv1d.bias, activation=self.activation ) @@ -174,9 +176,9 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=Non A = -torch.exp(self.A_log.float()) # 3.c perform the recurrence y ← SSM(A, B, C)(x) time_proj_bias = self.dt_proj.bias.float() if hasattr(self.dt_proj, "bias") else None - if inference_params is not None and inference_params.seqlen_offset > 0: + if cache_params is not None and cache_params.seqlen_offset > 0: scan_outputs = selective_state_update( - inference_params.ssm_states[self.layer_idx], + cache_params.ssm_states[self.layer_idx], hidden_states[..., 0], discrete_time_step[..., 0], A, @@ -201,34 +203,38 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, inference_params=Non return_last_state=True, ) if ssm_state is not None: - inference_params.ssm_states[self.layer_idx] = ssm_state + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) - # 4. Final linear projection - contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) + # 4. Final linear projection + contextualized_states = self.out_proj(scan_outputs.transpose(1, 2)) return contextualized_states # fmt: off - def slow_forward(self, input_states, inference_params=None): - _, seq_len, _ = input_states.shape + def slow_forward(self, input_states, cache_params=None): + batch_size, seq_len, _ = input_states.shape dtype = input_states.dtype # 1. Gated MLP's linear projection projected_states = self.in_proj(input_states).transpose(1, 2) # [batch, 2 * intermediate_size, seq_len] hidden_states, gate = projected_states.chunk(2, dim=1) # 2. Convolution sequence transformation - if inference_params.seqlen_offset > 0: - conv_state = inference_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] - conv_state = torch.roll(conv_state, shifts=-1, dims=-1) - conv_state[:, :, -1] = hidden_states[:, :, 0] - inference_params.conv_states[self.layer_idx].copy_(conv_state) - - bias = getattr(self.conv1d, "bias", 0.0) - hidden_states = self.act(torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + bias).to(dtype) - hidden_states = hidden_states.unsqueeze(-1) # [batch, intermediate_size, 1] : decoding + if cache_params is not None: + ssm_state = cache_params.ssm_states[self.layer_idx] + if cache_params.seqlen_offset > 0: + conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size] + conv_state = torch.roll(conv_state, shifts=-1, dims=-1) + conv_state[:, :, -1] = hidden_states[:, :, 0] + cache_params.conv_states[self.layer_idx].copy_(conv_state) + + bias = getattr(self.conv1d, "bias", 0.0) + hidden_states = self.act(torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + bias).to(dtype) + hidden_states = hidden_states.unsqueeze(-1) # [batch, intermediate_size, 1] : decoding + else: + conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) + cache_params.conv_states[self.layer_idx].copy_(conv_state) + hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] else: - inference_params.conv_states[self.layer_idx] = nn.functional.pad( - hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) - ).to(inference_params.dtype) + ssm_state = torch.zeros((batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] # 3. State Space Model sequence transformation @@ -247,7 +253,6 @@ def slow_forward(self, input_states, inference_params=None): deltaB_u = discrete_B * hidden_states[:, :, :, None].float() # 3.c perform the recurrence y ← SSM(A, B, C)(x) - ssm_state = inference_params.ssm_states[self.layer_idx] scan_outputs = [] for i in range(seq_len): ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :] # [batch, intermediade_size, ssm_state] @@ -257,16 +262,18 @@ def slow_forward(self, input_states, inference_params=None): scan_output = scan_output + (hidden_states * self.D[None, :, None]) scan_output = (scan_output * self.act(gate)) - inference_params.ssm_states[self.layer_idx].copy_(ssm_state) + if cache_params is not None: + cache_params.ssm_states[self.layer_idx].copy_(ssm_state) + # 4. Final linear projection contextualized_states = self.out_proj(scan_output.transpose(1, 2)) # [batch, seq_len, hidden_size] return contextualized_states # fmt: on - def forward(self, hidden_states, inference_params=None): + def forward(self, hidden_states, cache_params=None): if is_fast_path_available and "cuda" in self.x_proj.weight.device.type: - return self.cuda_kernels_forward(hidden_states, inference_params) - return self.slow_forward(hidden_states, inference_params) + return self.cuda_kernels_forward(hidden_states, cache_params) + return self.slow_forward(hidden_states, cache_params) class MambaCache: @@ -313,13 +320,13 @@ def __init__(self, config, layer_idx): self.norm = MambaRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.mixer = MambaMixer(config, layer_idx=layer_idx) - def forward(self, hidden_states, inference_params=None): + def forward(self, hidden_states, cache_params=None): residual = hidden_states hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype)) - if self.residual_in_fp32 or True: + if self.residual_in_fp32: residual = residual.to(torch.float32) - hidden_states = self.mixer(hidden_states, inference_params=inference_params) + hidden_states = self.mixer(hidden_states, cache_params=cache_params) hidden_states = residual + hidden_states return hidden_states @@ -391,7 +398,7 @@ class MambaOutput(ModelOutput): Args: last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. - inference_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`): + cache_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. @@ -404,7 +411,7 @@ class MambaOutput(ModelOutput): """ last_hidden_state: torch.FloatTensor = None - inference_params: Optional[List[torch.FloatTensor]] = None + cache_params: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None @@ -418,7 +425,7 @@ class MambaCausalLMOutput(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - inference_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`): + cache_params (list of five `torch.FloatTensor` of shape `(batch_size, hidden_size, num_hidden_layers)`): The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to avoid providing the old `input_ids`. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): @@ -430,7 +437,7 @@ class MambaCausalLMOutput(ModelOutput): loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None - inference_params: Optional[List[torch.FloatTensor]] = None + cache_params: Optional[List[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None @@ -455,7 +462,7 @@ class MambaCausalLMOutput(ModelOutput): input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): Indices of input sequence tokens in the vocabulary. - If `inference_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as + If `cache_params.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as `input_ids`. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and @@ -466,11 +473,11 @@ class MambaCausalLMOutput(ModelOutput): Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This is useful if you want more control over how to convert `input_ids` indices into associated vectors than the model's internal embedding lookup matrix. - inference_params (`MambaCache`, *optional*): + cache_params (`MambaCache`, *optional*): If passed along, the model uses the previous state in all the blocks (which will give the output for the `input_ids` provided as if the model add `state_input_ids + input_ids` as context). use_cache (`bool`, *optional*): - If set to `True`, the `inference_params` is returned and can be used to quickly generate the next logits. + If set to `True`, the `cache_params` is returned and can be used to quickly generate the next logits. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. @@ -511,7 +518,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, - inference_params: Optional[List[torch.FloatTensor]] = None, + cache_params: Optional[List[torch.FloatTensor]] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -531,33 +538,28 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embeddings(input_ids) - if inference_params is None: - inference_params = MambaCache( - self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype - ) - if self.gradient_checkpointing and self.training: if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`" - ) use_cache = False + if cache_params is None and use_cache: + cache_params = MambaCache( + self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype + ) + hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None for mixer_block in self.layers: if self.gradient_checkpointing and self.training: - hidden_states = self._gradient_checkpointing_func( - mixer_block.__call__, hidden_states, inference_params - ) + hidden_states = self._gradient_checkpointing_func(mixer_block.__call__, hidden_states, cache_params) else: - hidden_states = mixer_block(hidden_states, inference_params=inference_params) + hidden_states = mixer_block(hidden_states, cache_params=cache_params) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) if use_cache: - inference_params.seqlen_offset += inputs_embeds.shape[1] + cache_params.seqlen_offset += inputs_embeds.shape[1] hidden_states = self.norm_f(hidden_states) @@ -565,11 +567,11 @@ def forward( all_hidden_states = all_hidden_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, inference_params, all_hidden_states] if v is not None) + return tuple(v for v in [hidden_states, cache_params, all_hidden_states] if v is not None) return MambaOutput( last_hidden_state=hidden_states, - inference_params=inference_params if use_cache else None, + cache_params=cache_params if use_cache else None, hidden_states=all_hidden_states, ) @@ -606,22 +608,22 @@ def set_input_embeddings(self, new_embeddings): def _update_model_kwargs_for_generation( self, outputs: ModelOutput, model_kwargs: Dict[str, Any], **kwargs ) -> Dict[str, Any]: - model_kwargs["inference_params"] = outputs["inference_params"] + model_kwargs["cache_params"] = outputs["cache_params"] return model_kwargs def prepare_inputs_for_generation( - self, input_ids, inference_params=None, inputs_embeds=None, attention_mask=None, **kwargs + self, input_ids, cache_params=None, inputs_embeds=None, attention_mask=None, **kwargs ): # only last token for inputs_ids if the state is passed along. - if inference_params is not None: + if cache_params is not None: input_ids = input_ids[:, -1].unsqueeze(-1) - if inputs_embeds is not None and inference_params is None: + if inputs_embeds is not None and cache_params is None: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} - model_inputs["inference_params"] = inference_params + model_inputs["cache_params"] = cache_params return model_inputs @add_start_docstrings_to_model_forward(MAMBA_INPUTS_DOCSTRING) @@ -634,7 +636,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - inference_params: Optional[torch.FloatTensor] = None, + cache_params: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, @@ -650,7 +652,7 @@ def forward( mamba_outputs = self.backbone( input_ids, - inference_params=inference_params, + cache_params=cache_params, inputs_embeds=inputs_embeds, output_hidden_states=output_hidden_states, return_dict=return_dict, @@ -677,6 +679,6 @@ def forward( return MambaCausalLMOutput( loss=loss, logits=logits, - inference_params=mamba_outputs.inference_params, + cache_params=mamba_outputs.cache_params, hidden_states=mamba_outputs.hidden_states, ) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 928621fdd873ce..f536b634930870 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -14,6 +14,7 @@ # limitations under the License. +import math import unittest from unittest.util import safe_repr @@ -189,7 +190,7 @@ def create_and_check_state_equivalency(self, config, input_ids, *args): output_one = outputs.last_hidden_state # Using the state computed on the first inputs, we will get the same output - outputs = model(input_ids[:, -1:], inference_params=outputs.inference_params) + outputs = model(input_ids[:, -1:], cache_params=outputs.cache_params) output_two = outputs.last_hidden_state self.parent.assertTrue(torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5)) @@ -277,7 +278,7 @@ def test_multi_gpu_data_parallel_forward(self): # some params shouldn't be scattered by nn.DataParallel # so just remove them if they are present. - blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"] + blacklist_non_batched_params = ["cache_params"] for k in blacklist_non_batched_params: inputs_dict.pop(k, None) @@ -314,18 +315,24 @@ def test_initialization(self): for model_class in self.all_model_classes: model = model_class(config=config) for name, param in model.named_parameters(): - if "A" in name: + if "dt_proj.bias" in name: + dt = torch.exp( + torch.tensor([0, 1]) * (math.log(config.time_step_max) - math.log(config.time_step_min)) + + math.log(config.time_step_min) + ).clamp(min=config.time_step_floor) + inv_dt = dt + torch.log(-torch.expm1(-dt)) if param.requires_grad: - self.assertTrue(param.data.max().item() == 3.0) - self.assertTrue(param.data.min().item() == -5.0) - elif "B" in name: + self.assertTrue(param.data.max().item() <= inv_dt[1]) + self.assertTrue(param.data.min().item() >= inv_dt[0]) + elif "A_log" in name: + A = torch.arange(1, config.state_size + 1, dtype=torch.float32)[None, :] + self.assertTrue(torch.allclose(param.data, torch.log(A), atol=1e-5, rtol=1e-5)) + elif "D" in name: if param.requires_grad: # check if it's a ones like self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5)) - # TODO handle initialization scheme! - - @unittest.skip("Mamba does not use attention equivalent test should be `test_ssm_outputs`") + @unittest.skip("Mamba does not use attention") def test_attention_outputs(self): r""" Overriding the test_attention_outputs test as the attention outputs of Mamba are different from other models @@ -333,9 +340,6 @@ def test_attention_outputs(self): """ pass - def test_ssm_outputs(self): - pass - @slow def test_model_from_pretrained(self): for model_name in MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: From 6bb659a86b4dd9f470a9e7eda1fa2a39de2c5add Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Mar 2024 02:38:30 +0100 Subject: [PATCH 099/116] small nits --- .../models/mamba/modeling_mamba.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 999d0429e0815e..564cd757edcb00 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -91,7 +91,7 @@ def __init__(self, config, layer_idx): self.intermediate_size = config.intermediate_size self.time_step_rank = config.time_step_rank self.layer_idx = layer_idx - + self.use_conv_bias = config.use_conv_bias self.conv1d = nn.Conv1d( in_channels=self.intermediate_size, out_channels=self.intermediate_size, @@ -119,20 +119,21 @@ def __init__(self, config, layer_idx): self.A_log = nn.Parameter(torch.log(A)) self.D = nn.Parameter(torch.ones(self.intermediate_size)) self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) + self.use_bias = config.use_bias def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params=None): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) - if self.training and cache_params is None: # Doesn't support outputting the states -> used for training + if self.training and cache_params is None or False: # Doesn't support outputting the states -> used for training contextualized_states = mamba_inner_fn( projected_states, self.conv1d.weight, - self.conv1d.bias, + self.conv1d.bias if self.use_conv_bias else None, self.x_proj.weight, self.dt_proj.weight, self.out_proj.weight, - self.out_proj.bias, + self.out_proj.bias.float() if self.use_bias else None, -torch.exp(self.A_log.float()), None, # input-dependent B None, # input-dependent C @@ -202,7 +203,7 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params=None): delta_softplus=True, return_last_state=True, ) - if ssm_state is not None: + if ssm_state is not None and cache_params is not None: cache_params.ssm_states[self.layer_idx].copy_(ssm_state) # 4. Final linear projection @@ -225,10 +226,10 @@ def slow_forward(self, input_states, cache_params=None): conv_state = torch.roll(conv_state, shifts=-1, dims=-1) conv_state[:, :, -1] = hidden_states[:, :, 0] cache_params.conv_states[self.layer_idx].copy_(conv_state) - - bias = getattr(self.conv1d, "bias", 0.0) - hidden_states = self.act(torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + bias).to(dtype) - hidden_states = hidden_states.unsqueeze(-1) # [batch, intermediate_size, 1] : decoding + hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) + if self.use_conv_bias: + hidden_states += self.conv1d.bias + hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding else: conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) cache_params.conv_states[self.layer_idx].copy_(conv_state) @@ -538,9 +539,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embeddings(input_ids) - if self.gradient_checkpointing and self.training: - if use_cache: - use_cache = False + if self.gradient_checkpointing and self.training and use_cache: + use_cache = False if cache_params is None and use_cache: cache_params = MambaCache( From 178fe76b8759085f3178fab8bcd1ee7df178c161 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Mar 2024 03:31:56 +0100 Subject: [PATCH 100/116] more nits --- src/transformers/models/mamba/modeling_mamba.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 564cd757edcb00..fb48c864abbe21 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -55,9 +55,8 @@ ) causal_conv1d_update, causal_conv1d_fn = None, None -is_fast_path_available = ( - any((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update)) is not None -) +selective_state_update = None +is_fast_path_available = all((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)) _CHECKPOINT_FOR_DOC = "ArthurZ/mamba-130m" _CONFIG_FOR_DOC = "MambaConfig" @@ -125,7 +124,7 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params=None): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) - if self.training and cache_params is None or False: # Doesn't support outputting the states -> used for training + if self.training and cache_params is None: # Doesn't support outputting the states -> used for training contextualized_states = mamba_inner_fn( projected_states, self.conv1d.weight, From 3a46724eaf48b431cf4f0d6d043493dba9cd1c88 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Mar 2024 03:41:37 +0100 Subject: [PATCH 101/116] fix final CIs --- tests/models/mamba/test_modeling_mamba.py | 63 ++++++++++++++++++++++- utils/check_config_attributes.py | 2 + 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index f536b634930870..75c98672b0e685 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -17,12 +17,12 @@ import math import unittest from unittest.util import safe_repr - +from typing import Tuple, List, Dict from parameterized import parameterized from transformers import AutoTokenizer, MambaConfig, is_torch_available from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device - +from transformers.models.mamba.modeling_mamba import MambaCache from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, ids_tensor @@ -347,6 +347,65 @@ def test_model_from_pretrained(self): self.assertIsNotNone(model) + + def test_model_outputs_equivalence(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + with torch.no_grad(): + tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, MambaCache): # MODIFIED PART START + recursive_check(tuple_object.conv_states, dict_object.conv_states) + recursive_check(tuple_object.ssm_states, dict_object.ssm_states) + elif isinstance(tuple_object, (List, Tuple)): # MODIFIED PART END + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip( + tuple_object.values(), dict_object.values() + ): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + torch.allclose(tuple_object, dict_object, atol=1e-5), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" + f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." + ), + ) + + recursive_check(tuple_output, dict_output) + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + + @require_torch class MambaIntegrationTests(unittest.TestCase): def setUp(self): diff --git a/utils/check_config_attributes.py b/utils/check_config_attributes.py index da4a1210357daf..9701b8293407af 100644 --- a/utils/check_config_attributes.py +++ b/utils/check_config_attributes.py @@ -34,6 +34,8 @@ SPECIAL_CASES_TO_ALLOW = { # used to compute the property `self.chunk_length` "EncodecConfig": ["overlap"], + # used as in the config to define `intermediate_size` + "MambaConfig": ["expand"], # used as `self.bert_model = BertModel(config, ...)` "DPRConfig": True, "FuyuConfig": True, From 13204e083863c37237d1898ca6776abc694ae8f5 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Mar 2024 03:42:47 +0100 Subject: [PATCH 102/116] style --- src/transformers/models/mamba/modeling_mamba.py | 4 +++- tests/models/mamba/test_modeling_mamba.py | 11 +++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index fb48c864abbe21..aea9370343d0e5 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -56,7 +56,9 @@ causal_conv1d_update, causal_conv1d_fn = None, None selective_state_update = None -is_fast_path_available = all((selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)) +is_fast_path_available = all( + (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) +) _CHECKPOINT_FOR_DOC = "ArthurZ/mamba-130m" _CONFIG_FOR_DOC = "MambaConfig" diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 75c98672b0e685..72f057c539bf1b 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -16,13 +16,15 @@ import math import unittest +from typing import Dict, List, Tuple from unittest.util import safe_repr -from typing import Tuple, List, Dict + from parameterized import parameterized from transformers import AutoTokenizer, MambaConfig, is_torch_available -from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device from transformers.models.mamba.modeling_mamba import MambaCache +from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device + from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, ids_tensor @@ -346,8 +348,6 @@ def test_model_from_pretrained(self): model = MambaModel.from_pretrained(model_name) self.assertIsNotNone(model) - - def test_model_outputs_equivalence(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -357,7 +357,7 @@ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() def recursive_check(tuple_object, dict_object): - if isinstance(tuple_object, MambaCache): # MODIFIED PART START + if isinstance(tuple_object, MambaCache): # MODIFIED PART START recursive_check(tuple_object.conv_states, dict_object.conv_states) recursive_check(tuple_object.ssm_states, dict_object.ssm_states) elif isinstance(tuple_object, (List, Tuple)): # MODIFIED PART END @@ -405,7 +405,6 @@ def recursive_check(tuple_object, dict_object): check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) - @require_torch class MambaIntegrationTests(unittest.TestCase): def setUp(self): From 1608a9052d725e9d4a18066808e15910d17ddc83 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Mar 2024 03:45:19 +0100 Subject: [PATCH 103/116] nit doc --- src/transformers/models/mamba/configuration_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 6e0a9b442d1e7f..1f1c97c53e84d6 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -67,7 +67,7 @@ class MambaConfig(PretrainedConfig): The non-linear activation function (function or string) in the decoder. initializer_range (`float`, *optional*, defaults to 0.1): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - residual_in_fp32 (`bool`, *optional*, defaults to `False`): + residual_in_fp32 (`bool`, *optional*, defaults to `True`): Whether or not residuals should be in `float32`. time_step_rank (`int`, *optional*, defaults to `"auto"`): rank fo the discretization projection matrix. time_step_scale (``, *optional*, defaults to 1.0): From 99119ba20ac92d5649463069497412fcd9ab834b Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Mar 2024 03:50:02 +0100 Subject: [PATCH 104/116] I hope final doc nits --- .../models/mamba/configuration_mamba.py | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 1f1c97c53e84d6..aeab6585d54c8b 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -69,13 +69,21 @@ class MambaConfig(PretrainedConfig): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. residual_in_fp32 (`bool`, *optional*, defaults to `True`): Whether or not residuals should be in `float32`. - time_step_rank (`int`, *optional*, defaults to `"auto"`): rank fo the discretization projection matrix. - time_step_scale (``, *optional*, defaults to 1.0): - time_step_min (``, *optional*, defaults to 0.001): - time_step_max (``, *optional*, defaults to 0.1): - time_step_init_scheme (``, *optional*, defaults to `"random"`): - time_step_floor (``, *optional*, defaults to 0.0001): - rescale_prenorm_residual (``, *optional*, defaults to `False`): + time_step_rank (`int`, *optional*, defaults to `"auto"`): + Rank of the the discretization projection matrix. + time_step_scale (`float`, *optional*, defaults to 1.0): + Scale used used to scale `dt_proj.bias`. + time_step_min (`float`, *optional*, defaults to 0.001): + Minimum `time_step` used to bound `dt_proj.bias`. + time_step_max (`float`, *optional*, defaults to 0.1): + Maximum `time_step` used to bound `dt_proj.bias`. + time_step_init_scheme (`float`, *optional*, defaults to `"random"`): + Init scheme used for `dt_proj.weight`. Should be one of `["random","uniform"]` + time_step_floor (`float`, *optional*, defaults to 0.0001): + Minimum clamping value of the `dt_proj.bias` layer initialization. + rescale_prenorm_residual (`bool`, *optional*, defaults to `False`): + Whether or not to rescale `out_proj` weights when initializing. + use_cache (`bool`, *optional*, defaults to `True`): Whether or not the cache should be used. From d6fb1efa6d093b0afa3a17444917b396ffc616d5 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Mar 2024 03:50:40 +0100 Subject: [PATCH 105/116] nit --- src/transformers/models/mamba/configuration_mamba.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index aeab6585d54c8b..2839a14616a6f1 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -83,7 +83,6 @@ class MambaConfig(PretrainedConfig): Minimum clamping value of the `dt_proj.bias` layer initialization. rescale_prenorm_residual (`bool`, *optional*, defaults to `False`): Whether or not to rescale `out_proj` weights when initializing. - use_cache (`bool`, *optional*, defaults to `True`): Whether or not the cache should be used. From 844530fdf443ae93226373aa1ba77c5fa4731fad Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Mar 2024 03:52:54 +0100 Subject: [PATCH 106/116] =?UTF-8?q?=F0=9F=AB=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/transformers/models/mamba/configuration_mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 2839a14616a6f1..965039386106da 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -58,7 +58,7 @@ class MambaConfig(PretrainedConfig): The id of the end of sentence token in the vocabulary. Defaults to 0 as MAMBA uses the same tokenizer as GPTNeoX. expand (`int`, *optional*, defaults to 2): Expanding factor used to determin the intermediate size. - conv_kernel (``, *optional*, defaults to 4): + conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel. use_bias (`bool`, *optional*, defaults to `False`): Whether or not to use bias in ["in_proj", "out_proj"] of the mixer block use_conv_bias (`bool`, *optional*, defaults to `True`): From 52be01857a4ea009b4d8491131c635c8d1157476 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Mar 2024 03:56:35 +0100 Subject: [PATCH 107/116] final touch! --- src/transformers/models/mamba/modeling_mamba.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index aea9370343d0e5..e1e09a694aaae3 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -55,7 +55,6 @@ ) causal_conv1d_update, causal_conv1d_fn = None, None -selective_state_update = None is_fast_path_available = all( (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn) ) From d03de1c1070cee0fb7b9358669abc3f68ef56b75 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Mon, 4 Mar 2024 04:18:34 +0100 Subject: [PATCH 108/116] fix torch import --- tests/models/mamba/test_modeling_mamba.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index 72f057c539bf1b..e9e247954f4360 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -22,7 +22,6 @@ from parameterized import parameterized from transformers import AutoTokenizer, MambaConfig, is_torch_available -from transformers.models.mamba.modeling_mamba import MambaCache from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device from ...generation.test_utils import GenerationTesterMixin @@ -39,6 +38,8 @@ MambaForCausalLM, MambaModel, ) + from transformers.models.mamba.modeling_mamba import MambaCache + from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0 else: is_torch_greater_or_equal_than_2_0 = False From c0672a8b5ab1b698ca5ab027373435899a523208 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 5 Mar 2024 01:59:21 +0100 Subject: [PATCH 109/116] Apply suggestions from code review Co-authored-by: Lysandre Debut --- docs/source/en/model_doc/mamba.md | 2 +- .../models/mamba/configuration_mamba.py | 2 +- .../models/mamba/modeling_mamba.py | 19 +++++++++---------- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/docs/source/en/model_doc/mamba.md b/docs/source/en/model_doc/mamba.md index 70218937fe8186..65b09c9f789030 100644 --- a/docs/source/en/model_doc/mamba.md +++ b/docs/source/en/model_doc/mamba.md @@ -34,7 +34,7 @@ Tips: - in order to run the fast version of the model you should install `causal_conv1d` and `mamba` -This model was contributed by [INSERT YOUR HF USERNAME HERE](https://huggingface.co/ArthurZ). +This model was contributed by [ArthurZ](https://huggingface.co/ArthurZ). The original code can be found [here](https://github.com/state-spaces/mamba). diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 965039386106da..6ba5efd717b364 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -31,7 +31,7 @@ class MambaConfig(PretrainedConfig): """ This is the configuration class to store the configuration of a [`MambaModel`]. It is used to instantiate a MAMBA model according to the specified arguments, defining the model architecture. Instantiating a configuration with the - defaults will yield a similar configuration to that of the RWVK-4 + defaults will yield a similar configuration to that of the MAMBA [state-spaces/mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index e1e09a694aaae3..228cb1ddcc207f 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -62,14 +62,7 @@ _CHECKPOINT_FOR_DOC = "ArthurZ/mamba-130m" _CONFIG_FOR_DOC = "MambaConfig" -MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = [ - "ArthurZ/mamba-130m", - "state-spaces/mamba-370m", - "state-spaces/mamba-790m", - "state-spaces/mamba-1.4b", - "state-spaces/mamba-2.8b", - "state-spaces/mamba-2.8b-slimpj", -] # See all Mamba models at https://huggingface.co/models?filter=mamba +MAMBA_PRETRAINED_MODEL_ARCHIVE_LIST = [] # See all Mamba models at https://huggingface.co/models?filter=mamba class MambaMixer(nn.Module): @@ -231,11 +224,17 @@ def slow_forward(self, input_states, cache_params=None): hidden_states += self.conv1d.bias hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding else: - conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0)) + conv_state = nn.functional.pad( + hidden_states, + (self.conv_kernel_size - hidden_states.shape[-1], 0) + ) cache_params.conv_states[self.layer_idx].copy_(conv_state) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] else: - ssm_state = torch.zeros((batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype) + ssm_state = torch.zeros( + (batch_size, self.intermediate_size, self.ssm_state_size), + device=hidden_states.device, dtype=dtype + ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] # 3. State Space Model sequence transformation From dfc1212d7466aa7176415452ae271582f019e0db Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 5 Mar 2024 01:59:54 +0100 Subject: [PATCH 110/116] Apply suggestions from code review --- .../models/mamba/configuration_mamba.py | 23 ++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/mamba/configuration_mamba.py b/src/transformers/models/mamba/configuration_mamba.py index 6ba5efd717b364..dd1dd129aec633 100644 --- a/src/transformers/models/mamba/configuration_mamba.py +++ b/src/transformers/models/mamba/configuration_mamba.py @@ -1,6 +1,5 @@ # coding=utf-8 -# Copyright 2023 The OpenAI Team Authors and HuggingFace Inc. team. -# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# Copyright 2024 The HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -51,12 +50,10 @@ class MambaConfig(PretrainedConfig): The epsilon to use in the layer normalization layers. pad_token_id (`int`, *optional*, defaults to 0): Padding token id. - bos_token_id (`int`, *optional*, defaults to 1): - The id of the beginning of sentence token in the vocabulary. Defaults to 0 as MAMBA uses the same tokenizer - as GPTNeoX. - eos_token_id (`int`, *optional*, defaults to 2): - The id of the end of sentence token in the vocabulary. Defaults to 0 as MAMBA uses the same tokenizer as - GPTNeoX. + bos_token_id (`int`, *optional*, defaults to 0): + The id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 0): + The id of the end of sentence token in the vocabulary. expand (`int`, *optional*, defaults to 2): Expanding factor used to determin the intermediate size. conv_kernel (`int`, *optional*, defaults to 4): Size of the convolution kernel. use_bias (`bool`, *optional*, defaults to `False`): @@ -68,9 +65,9 @@ class MambaConfig(PretrainedConfig): initializer_range (`float`, *optional*, defaults to 0.1): The standard deviation of the truncated_normal_initializer for initializing all weight matrices. residual_in_fp32 (`bool`, *optional*, defaults to `True`): - Whether or not residuals should be in `float32`. - time_step_rank (`int`, *optional*, defaults to `"auto"`): - Rank of the the discretization projection matrix. + Whether or not residuals should be in `float32`. If set to `False` residuals will keep the same `dtype` as the rest of the model + time_step_rank (`Union[int,str]`, *optional*, defaults to `"auto"`): + Rank of the the discretization projection matrix. `"auto"` means that it will default to `math.ceil(self.hidden_size / 16)` time_step_scale (`float`, *optional*, defaults to 1.0): Scale used used to scale `dt_proj.bias`. time_step_min (`float`, *optional*, defaults to 0.001): @@ -112,8 +109,8 @@ def __init__( num_hidden_layers=32, layer_norm_epsilon=1e-5, pad_token_id=0, - bos_token_id=1, - eos_token_id=2, + bos_token_id=0, + eos_token_id=0, expand=2, conv_kernel=4, use_bias=False, From acd4ccf1a68745c50bb6cedb7b0ba886d7c5930f Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 5 Mar 2024 10:23:52 +0900 Subject: [PATCH 111/116] fix fix and fix --- src/transformers/kernels/mamba/Makefile | 28 --------------- .../models/mamba/modeling_mamba.py | 34 +++++++++---------- tests/models/mamba/test_modeling_mamba.py | 1 - 3 files changed, 16 insertions(+), 47 deletions(-) delete mode 100644 src/transformers/kernels/mamba/Makefile diff --git a/src/transformers/kernels/mamba/Makefile b/src/transformers/kernels/mamba/Makefile deleted file mode 100644 index f4dec8688abc3e..00000000000000 --- a/src/transformers/kernels/mamba/Makefile +++ /dev/null @@ -1,28 +0,0 @@ -selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137 - -causal-conv1d: - rm -rf causal-conv1d - git clone https://github.com/Dao-AILab/causal-conv1d.git - -build-causal-conv1d: causal-conv1d - cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag - cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build - -install-causal-conv1d: build-causal-conv1d - pip uninstall causal-conv1d -y || true - cd causal-conv1d/ && pip install . - -# selective-scan dependends on causal-conv1d -selective-scan: - rm -rf mamba - git clone https://github.com/state-spaces/mamba.git mamba - -build-selective-scan: selective-scan - cd mamba/ && git fetch && git checkout $(selective_scan_commit) - cd mamba && python setup.py build - -install-selective-scan: install-causal-conv1d build-selective-scan - pip uninstall selective-scan-cuda -y || true - cd mamba && pip install . - -build-all: build-causal-conv1d build-selective-scan \ No newline at end of file diff --git a/src/transformers/models/mamba/modeling_mamba.py b/src/transformers/models/mamba/modeling_mamba.py index 228cb1ddcc207f..490020531ce53e 100644 --- a/src/transformers/models/mamba/modeling_mamba.py +++ b/src/transformers/models/mamba/modeling_mamba.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Copyright 2024 Tri Dao, Albert Gu and HuggingFace Inc. team. +# Copyright 2024 state-spaces/mamba org and HuggingFace Inc. team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -42,17 +42,11 @@ from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn from mamba_ssm.ops.triton.selective_state_update import selective_state_update else: - logger.warning_once( - "The `mamba_ssm` package is not installed in your environnement. Make sure to install it if you want to use the custom cuda kernels" - ) selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None if is_causal_conv1d_available(): from causal_conv1d import causal_conv1d_fn, causal_conv1d_update else: - logger.warning_once( - "The `causal_conv1d` package is not installed in your environnement. Make sure to install it if you want to use the custom cuda kernels" - ) causal_conv1d_update, causal_conv1d_fn = None, None is_fast_path_available = all( @@ -67,13 +61,10 @@ class MambaMixer(nn.Module): """ - Selective layer TODO DOC DOC DOC - Compute ∆ A B C D, the state space parameters. - A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) - ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, - and is why Mamba is called **selective** state spaces) - - ∆, B and C are the `selective` parameters + Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`. + A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective) + ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4, + and is why Mamba is called **selective** state spaces) """ def __init__(self, config, layer_idx): @@ -114,6 +105,13 @@ def __init__(self, config, layer_idx): self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias) self.use_bias = config.use_bias + if not is_fast_path_available: + logger.warning_once( + "The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)`" + " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and" + " https://github.com/Dao-AILab/causal-conv1d" + ) + def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params=None): # 1. Gated MLP's linear projection projected_states = self.in_proj(hidden_states).transpose(1, 2) @@ -222,17 +220,17 @@ def slow_forward(self, input_states, cache_params=None): hidden_states = torch.sum(conv_state * self.conv1d.weight[:, 0, :], dim=-1) if self.use_conv_bias: hidden_states += self.conv1d.bias - hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding + hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # [batch, intermediate_size, 1] : decoding else: conv_state = nn.functional.pad( - hidden_states, + hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0) ) cache_params.conv_states[self.layer_idx].copy_(conv_state) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] else: ssm_state = torch.zeros( - (batch_size, self.intermediate_size, self.ssm_state_size), + (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype ) hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # [batch, intermediate_size, seq_len] @@ -530,7 +528,7 @@ def forward( use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if (input_ids is None) ^ (inputs_embeds is not None): + if (input_ids is None) ^ (inputs_embeds is not None): # ^ is python for xor raise ValueError( "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" ) diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index e9e247954f4360..7b0a00051b33fd 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -39,7 +39,6 @@ MambaModel, ) from transformers.models.mamba.modeling_mamba import MambaCache - from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_0 else: is_torch_greater_or_equal_than_2_0 = False From 2ddd9aadf50dc6c01ac34c55ac49a6d88a33167c Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Tue, 5 Mar 2024 10:27:58 +0900 Subject: [PATCH 112/116] fix base model prefix! --- docs/source/en/model_doc/mamba.md | 14 +++++++------- src/transformers/models/mamba/modeling_mamba.py | 2 +- tests/models/mamba/test_modeling_mamba.py | 7 +++---- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/docs/source/en/model_doc/mamba.md b/docs/source/en/model_doc/mamba.md index 65b09c9f789030..bd7cb278ce2820 100644 --- a/docs/source/en/model_doc/mamba.md +++ b/docs/source/en/model_doc/mamba.md @@ -1,4 +1,4 @@ -