diff --git a/docs/source/en/index.md b/docs/source/en/index.md
index 967049d89cbe12..82cac2f2262a59 100644
--- a/docs/source/en/index.md
+++ b/docs/source/en/index.md
@@ -375,6 +375,7 @@ Flax), PyTorch, and/or TensorFlow.
| [YOLOS](model_doc/yolos) | ✅ | ❌ | ❌ |
| [YOSO](model_doc/yoso) | ✅ | ❌ | ❌ |
| [Zamba](model_doc/zamba) | ✅ | ❌ | ❌ |
+| [Zamba2](model_doc/zamba2) | ✅ | ❌ | ❌ |
| [ZoeDepth](model_doc/zoedepth) | ✅ | ❌ | ❌ |
diff --git a/docs/source/en/model_doc/zamba2.md b/docs/source/en/model_doc/zamba2.md
new file mode 100644
index 00000000000000..75333555d45e56
--- /dev/null
+++ b/docs/source/en/model_doc/zamba2.md
@@ -0,0 +1,93 @@
+
+# Zamba2
+
+Zamba2 is a large language model (LLM) trained by Zyphra, and made available under an Apache 2.0 license. Please see the [Zyphra Hugging Face](https://huggingface.co/collections/zyphra/) repository for model weights.
+
+This model was contributed by [pglo](https://huggingface.co/pglo).
+
+
+## Model details
+
+Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B are hybrid models combining state-space models (Specifically [Mamba](https://github.com/state-spaces/mamba)) and transformer, and were trained using next-token prediction. Zamba2 uses shared transformer layers after every 6 mamba blocks. It uses the [Mistral v0.1 tokenizer](https://huggingface.co/mistralai/Mistral-7B-v0.1). We came to this architecture after a series of ablations at small scales. Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B were pre-trained on 2T and 3T tokens, respectively.
+
+
+
+## Quick start
+
+
+### Presequities
+
+Zamba2 requires you use `transformers` version 4.46.0 or higher:
+```bash
+pip install transformers>=4.46.0
+```
+
+## Inference
+
+```python
+from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
+
+tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
+model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16)
+
+input_text = "What factors contributed to the fall of the Roman Empire?"
+input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
+
+outputs = model.generate(**input_ids, max_new_tokens=100)
+print(tokenizer.decode(outputs[0]))
+```
+
+
+## Model card
+
+The model cards can be found at:
+* [Zamba2-1.2B](https://huggingface.co/Zyphra/Zamba2-1.2B)
+* [Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B)
+* [Zamba2-7B](https://huggingface.co/Zyphra/Zamba2-7B)
+
+
+## Issues
+For issues with model output, or community discussion, please use the Hugging Face community [forum](https://huggingface.co/Zyphra/Zamba2-7B/discussions)
+
+
+## License
+
+The model weights are open-sourced via an Apache 2.0 license.
+
+
+## Zamba2Config
+
+[[autodoc]] Zamba2Config
+
+
+## Zamba2Model
+
+[[autodoc]] Zamba2Model
+ - forward
+
+
+## Zamba2ForCausalLM
+
+[[autodoc]] Zamba2ForCausalLM
+ - forward
+
+
+## Zamba2ForSequenceClassification
+
+[[autodoc]] transformers.Zamba2ForSequenceClassification
+ - forward
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index 930f41b6fefba7..d063498d9161c0 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -106,6 +106,7 @@ FlashAttention-2 is currently supported for the following architectures:
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
+* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2)
You can request to add FlashAttention-2 support for another model by opening a GitHub Issue or Pull Request.
@@ -317,7 +318,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
-
+* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2)
FlashAttention can only be used for models with the `fp16` or `bf16` torch type, so make sure to cast your model to the appropriate type first. The memory-efficient attention backend is able to handle `fp32` models.
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 600d3d217fa8a9..bd6b22ac27edf2 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -871,6 +871,7 @@
"models.yolos": ["YolosConfig"],
"models.yoso": ["YosoConfig"],
"models.zamba": ["ZambaConfig"],
+ "models.zamba2": ["Zamba2Config"],
"models.zoedepth": ["ZoeDepthConfig"],
"onnx": [],
"pipelines": [
@@ -3889,6 +3890,14 @@
"ZambaPreTrainedModel",
]
)
+ _import_structure["models.zamba2"].extend(
+ [
+ "Zamba2ForCausalLM",
+ "Zamba2ForSequenceClassification",
+ "Zamba2Model",
+ "Zamba2PreTrainedModel",
+ ]
+ )
_import_structure["models.zoedepth"].extend(
[
"ZoeDepthForDepthEstimation",
@@ -5881,6 +5890,7 @@
from .models.yolos import YolosConfig
from .models.yoso import YosoConfig
from .models.zamba import ZambaConfig
+ from .models.zamba2 import Zamba2Config
from .models.zoedepth import ZoeDepthConfig
# Pipelines
@@ -8361,6 +8371,12 @@
ZambaModel,
ZambaPreTrainedModel,
)
+ from .models.zamba2 import (
+ Zamba2ForCausalLM,
+ Zamba2ForSequenceClassification,
+ Zamba2Model,
+ Zamba2PreTrainedModel,
+ )
from .models.zoedepth import (
ZoeDepthForDepthEstimation,
ZoeDepthPreTrainedModel,
diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py
index 7fcaddde704cf7..34b0906362976e 100644
--- a/src/transformers/models/__init__.py
+++ b/src/transformers/models/__init__.py
@@ -293,5 +293,6 @@
yolos,
yoso,
zamba,
+ zamba2,
zoedepth,
)
diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py
index 69ce8efa10c76c..1a21b32e064c67 100644
--- a/src/transformers/models/auto/configuration_auto.py
+++ b/src/transformers/models/auto/configuration_auto.py
@@ -324,6 +324,7 @@
("yolos", "YolosConfig"),
("yoso", "YosoConfig"),
("zamba", "ZambaConfig"),
+ ("zamba2", "Zamba2Config"),
("zoedepth", "ZoeDepthConfig"),
]
)
@@ -658,6 +659,7 @@
("yolos", "YOLOS"),
("yoso", "YOSO"),
("zamba", "Zamba"),
+ ("zamba2", "Zamba2"),
("zoedepth", "ZoeDepth"),
]
)
diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py
index e8a2dece432476..a164bbf316248c 100644
--- a/src/transformers/models/auto/modeling_auto.py
+++ b/src/transformers/models/auto/modeling_auto.py
@@ -296,6 +296,7 @@
("yolos", "YolosModel"),
("yoso", "YosoModel"),
("zamba", "ZambaModel"),
+ ("zamba2", "Zamba2Model"),
]
)
@@ -566,6 +567,7 @@
("xlnet", "XLNetLMHeadModel"),
("xmod", "XmodForCausalLM"),
("zamba", "ZambaForCausalLM"),
+ ("zamba2", "Zamba2ForCausalLM"),
]
)
@@ -1035,6 +1037,7 @@
("xmod", "XmodForSequenceClassification"),
("yoso", "YosoForSequenceClassification"),
("zamba", "ZambaForSequenceClassification"),
+ ("zamba2", "Zamba2ForSequenceClassification"),
]
)
diff --git a/src/transformers/models/auto/tokenization_auto.py b/src/transformers/models/auto/tokenization_auto.py
index 350c230f142c15..6b231814ed1838 100644
--- a/src/transformers/models/auto/tokenization_auto.py
+++ b/src/transformers/models/auto/tokenization_auto.py
@@ -571,6 +571,13 @@
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
+ (
+ "zamba2",
+ (
+ "LlamaTokenizer" if is_sentencepiece_available() else None,
+ "LlamaTokenizerFast" if is_tokenizers_available() else None,
+ ),
+ ),
]
)
diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py
index 3b7348eadd4785..0b2aea1370e94c 100644
--- a/src/transformers/models/zamba/modeling_zamba.py
+++ b/src/transformers/models/zamba/modeling_zamba.py
@@ -113,7 +113,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-class HybridMambaAttentionDynamicCache(DynamicCache):
+class ZambaHybridDynamicCache(DynamicCache):
"""
A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
(which has a constant shape regardless of seq_len).
@@ -131,9 +131,9 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None):
self.dtype = dtype
self.layers_block_type = config.layers_block_type
self.has_previous_state = False # only used by mamba
- intermediate_size = config.mamba_expand * config.hidden_size
- ssm_state_size = config.mamba_d_state
- conv_kernel_size = config.mamba_d_conv
+ self.intermediate_size = config.mamba_expand * config.hidden_size
+ self.ssm_state_size = config.mamba_d_state
+ self.conv_kernel_size = config.mamba_d_conv
self.n_mamba_heads = config.n_mamba_heads
self.conv_states = []
self.ssm_states = []
@@ -143,9 +143,14 @@ def __init__(self, config, batch_size, dtype=torch.float16, device=None):
self._buffers = {}
for i in range(config.num_hidden_layers):
self.conv_states += [
- torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
+ torch.zeros(batch_size, self.intermediate_size, self.conv_kernel_size, device=device, dtype=dtype)
]
- cache_shape = (batch_size, self.n_mamba_heads, intermediate_size // self.n_mamba_heads, ssm_state_size)
+ cache_shape = (
+ batch_size,
+ self.n_mamba_heads,
+ self.intermediate_size // self.n_mamba_heads,
+ self.ssm_state_size,
+ )
self.ssm_states += [torch.zeros(cache_shape, device=device, dtype=dtype)]
if self.layers_block_type[i] == "hybrid":
self.transformer_layers.append(i)
@@ -194,14 +199,12 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
return 0
return self.key_cache[layer_idx].shape[-2]
- # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.to_legacy_cache
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
- raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
+ raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.")
@classmethod
- # Copied from transformers.models.jamba.modeling_jamba.HybridMambaAttentionDynamicCache.from_legacy_cache
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
- raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")
+ raise NotImplementedError("ZambaHybridDynamicCache does not have a legacy cache equivalent.")
class ZambaAttention(nn.Module):
@@ -249,7 +252,7 @@ def forward(
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+ past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
@@ -326,7 +329,7 @@ def forward(
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+ past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
@@ -416,7 +419,7 @@ def forward(
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+ past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
@@ -568,7 +571,7 @@ def __init__(self, config: ZambaConfig, layer_idx):
)
def cuda_kernels_forward(
- self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None
+ self, hidden_states: torch.Tensor, cache_params: ZambaHybridDynamicCache = None, attention_mask=None
):
batch_size, seq_len, _ = hidden_states.shape
use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1
@@ -664,7 +667,7 @@ def cuda_kernels_forward(
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
return contextualized_states
- def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None):
+ def slow_forward(self, input_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated linear projection
@@ -675,7 +678,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
gate = gate.squeeze(2)
gate = gate.reshape(batch_size, self.n_mamba_heads, -1, seq_len).transpose(0, 1)
- use_cache = isinstance(cache_params, HybridMambaAttentionDynamicCache)
+ use_cache = isinstance(cache_params, ZambaHybridDynamicCache)
# 2. Convolution sequence transformation
if use_cache and cache_params.ssm_states[self.layer_idx].shape[0] == batch_size:
if self.training:
@@ -757,7 +760,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
)
return contextualized_states
- def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask=None):
+ def forward(self, hidden_states, cache_params: ZambaHybridDynamicCache = None, attention_mask=None):
if self.use_fast_kernels:
if not is_fast_path_available or "cuda" not in self.x_proj_weight.device.type:
raise ValueError(
@@ -802,7 +805,7 @@ def forward(
layer_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+ past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
@@ -817,7 +820,7 @@ def forward(
(see fig. 2 in https://arxiv.org/pdf/2405.16712).
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
- past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+ past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@@ -870,7 +873,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+ past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
@@ -881,7 +884,7 @@ def forward(
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
- past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+ past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@@ -923,7 +926,7 @@ def forward(
return outputs
-class HybridLayer(nn.Module):
+class ZambaHybridLayer(nn.Module):
def __init__(self, shared_transf: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer):
super().__init__()
self.shared_transf = shared_transf
@@ -938,7 +941,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
causal_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+ past_key_value: Optional[ZambaHybridDynamicCache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
@@ -951,7 +954,7 @@ def forward(
layer_idx (`int`): layer number.
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
- past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states
+ past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
@@ -1027,7 +1030,7 @@ class ZambaPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = False
_supports_sdpa = False
- _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache
+ _supports_cache_class = True # Note: only supports ZambaHybridDynamicCache
_is_stateful = True
def _init_weights(self, module):
@@ -1121,14 +1124,14 @@ def _check_and_enable_flash_attn_2(
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
- past_key_values (`HybridMambaAttentionDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
- A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
+ past_key_values (`ZambaHybridDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ A ZambaHybridDynamicCache object containing pre-computed hidden-states (keys and values in the
self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding.
Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
`(batch_size, d_inner, d_state)` respectively.
- See the `HybridMambaAttentionDynamicCache` class for more details.
+ See the `ZambaHybridDynamicCache` class for more details.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
@@ -1202,7 +1205,7 @@ def __init__(self, config: ZambaConfig):
"shared_transf.pre_ff_layernorm.weight",
]
self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
- layers.append(HybridLayer(block, next(linear_layers), next(mamba_layers)))
+ layers.append(ZambaHybridLayer(block, next(linear_layers), next(mamba_layers)))
else:
layers.append(next(mamba_layers))
self.layers = nn.ModuleList(layers)
@@ -1226,7 +1229,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+ past_key_values: Optional[ZambaHybridDynamicCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -1263,7 +1266,7 @@ def forward(
if use_cache and past_key_values is None:
logger.warning_once(
- "Zamba requires an initialized `HybridMambaAttentionDynamicCache` to return a cache. None was "
+ "Zamba requires an initialized `ZambaHybridDynamicCache` to return a cache. None was "
"provided, so no cache will be returned."
)
@@ -1410,7 +1413,7 @@ def forward(
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[HybridMambaAttentionDynamicCache] = None,
+ past_key_values: Optional[ZambaHybridDynamicCache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@@ -1504,7 +1507,7 @@ def prepare_inputs_for_generation(
use_cache=True,
**kwargs,
):
- # Overwitten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
+ # Overwitten -- has a unique cache type, `ZambaHybridDynamicCache`
empty_past_kv = past_key_values is None
@@ -1518,7 +1521,7 @@ def prepare_inputs_for_generation(
elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
input_ids = input_ids[:, cache_position]
else:
- past_key_values = HybridMambaAttentionDynamicCache(
+ past_key_values = ZambaHybridDynamicCache(
self.config, input_ids.shape[0], dtype=self.dtype, device=self.device
)
diff --git a/src/transformers/models/zamba2/__init__.py b/src/transformers/models/zamba2/__init__.py
new file mode 100644
index 00000000000000..00db458c72ebd5
--- /dev/null
+++ b/src/transformers/models/zamba2/__init__.py
@@ -0,0 +1,27 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING
+
+from ...utils import _LazyModule
+from ...utils.import_utils import define_import_structure
+
+
+if TYPE_CHECKING:
+ from .configuration_zamba2 import *
+ from .modeling_zamba2 import *
+else:
+ import sys
+
+ _file = globals()["__file__"]
+ sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
diff --git a/src/transformers/models/zamba2/configuration_zamba2.py b/src/transformers/models/zamba2/configuration_zamba2.py
new file mode 100644
index 00000000000000..01f1a8f7776edf
--- /dev/null
+++ b/src/transformers/models/zamba2/configuration_zamba2.py
@@ -0,0 +1,239 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/zamba2/modular_zamba2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_zamba2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_utils import PretrainedConfig
+
+
+class Zamba2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Zamba2Model`]. It is used to instantiate a
+ Zamba2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Zamba2 model.
+
+ [Zyphra/Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Zamba2Model`]
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model might ever be used with.
+ hidden_size (`int`, *optional*, defaults to 2560):
+ Dimension of the hidden representations.
+ num_hidden_layers (`int`, *optional*, defaults to 54):
+ Number of hidden layers in the model.
+ layers_block_type (`list`, *optional*):
+ List of layer types, which can be either "mamba" or "hybrid".
+ mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents.
+ mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel.
+ mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
+ mamba_ngroups (`int`, *optional*, defaults to 1):
+ Number of groups for the evolution matrices of mamba 2.
+ time_step_min (`float`, *optional*, defaults to 0.001):
+ Minimum `time_step` used to bound `dt_proj.bias`.
+ time_step_max (`float`, *optional*, defaults to 0.1):
+ Maximum `time_step` used to bound `dt_proj.bias`.
+ time_step_floor (`float`, *optional*, defaults to 0.0001):
+ Minimum clamping value of the `dt_proj.bias` layer initialization.
+ time_step_limit (`tuple`, *optional*):
+ Accepted range of time step values.
+ n_mamba_heads (`int`, *optional*, defaults to 8):
+ Number of heads for the evolution matrices of mamba 2.
+ use_conv_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not to use bias in the convolution layer of the mixer block.
+ chunk_size (`int`, *optional*, defaults to 256):
+ Size of the chunks that will comprise the sequence.
+ add_bias_linear (`bool`, *optional*, defaults to `False`):
+ Flag indicating whether or not to use bias in various layers
+ intermediate_size (`int`, *optional*, defaults to 4 * hidden_size):
+ Dimension of the MLP representations.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the MLP.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=None`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf).
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ num_mem_blocks (`int`, *optional*, defaults to 1):
+ Number of unshared transformer blocks.
+ use_shared_mlp_adapter (`bool`, *optional*, defaults to `False`):
+ If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the shared MLP's.
+ use_shared_attention_adapter (`bool`, *optional*, defaults to `False`):
+ If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the q, k, v projectors in the shared attention layers.
+ adapter_rank (`int`, *optional*, defaults to 128):
+ Rank of the adapter in the shared MLP and shared attention layers.
+ use_mem_rope (`bool`, *optional*, defaults to `False`):
+ If True, includes RoPE in the shared attention layers.
+ rope_theta (`float`, *optional*, defaults to `10000.0`):
+ The base period of the RoPE embeddings.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
+ Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
+ integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
+ logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
+ sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
+ significantly.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ The id of the padding token.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the "end-of-sequence" token.
+ use_long_context (`bool`, *optional*, defaults to `False`):
+ Activates the context-extended version of Zamba by modifying RoPE.
+ ```python
+ >>> from transformers import Zamba2Model, Zamba2Config
+ >>> # Initializing a Zamba2-2.7B style configuration
+ >>> configuration = Zamba2Config()
+ >>> # Initializing a model from the Zamba2-2.7B style configuration
+ >>> model = Zamba2Model(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ """
+
+ model_type = "zamba2"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ max_position_embeddings=4096,
+ hidden_size=2560,
+ num_hidden_layers=54,
+ layers_block_type=None,
+ mamba_d_state=64,
+ mamba_d_conv=4,
+ mamba_expand=2,
+ mamba_ngroups=1,
+ time_step_min=0.001,
+ time_step_max=0.1,
+ time_step_floor=1e-4,
+ time_step_limit=None,
+ n_mamba_heads=8,
+ use_conv_bias=True,
+ chunk_size=256,
+ add_bias_linear=False,
+ intermediate_size=None,
+ hidden_act="gelu",
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ attention_dropout=0.0,
+ num_mem_blocks=1,
+ use_shared_mlp_adapter=False,
+ use_shared_attention_adapter=False,
+ adapter_rank=128,
+ use_mem_rope=False,
+ rope_theta=10000,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ num_logits_to_keep=1,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ use_long_context=False,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ if intermediate_size is None:
+ self.intermediate_size = 4 * hidden_size
+ else:
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_mem_blocks = num_mem_blocks
+ self.attention_hidden_size = 2 * hidden_size
+ self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads
+ self.attention_dropout = attention_dropout
+ self.use_mem_rope = use_mem_rope
+ self.use_long_context = use_long_context
+ if use_mem_rope and use_long_context:
+ a = 8
+ rope_theta = rope_theta * a ** (self.attention_head_dim / (self.attention_head_dim - 2))
+ self.rope_theta = rope_theta
+ self.mamba_d_state = mamba_d_state
+ self.mamba_d_conv = mamba_d_conv
+ self.mamba_expand = mamba_expand
+ self.add_bias_linear = add_bias_linear
+ self.mamba_ngroups = mamba_ngroups
+ self.n_mamba_heads = n_mamba_heads
+ self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads
+ self.use_conv_bias = use_conv_bias
+ self.chunk_size = chunk_size
+ self.time_step_limit = time_step_limit
+ self.use_shared_mlp_adapter = use_shared_mlp_adapter
+ self.use_shared_attention_adapter = use_shared_attention_adapter
+ self.adapter_rank = adapter_rank
+ self.time_step_min = time_step_min
+ self.time_step_max = time_step_max
+ self.time_step_floor = time_step_floor
+ if use_long_context:
+ self.max_position_embeddings = 16384
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.num_attention_heads = num_attention_heads
+ self.kv_channels = self.hidden_size // self.num_attention_heads
+ self.num_query_groups = self.num_attention_heads
+ # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer)
+ if layers_block_type is None:
+ self.layers_block_type = (
+ ["mamba"]
+ + (["mamba"] * 5 + ["hybrid"]) * 7
+ + ["mamba"] * 4
+ + ["hybrid"]
+ + ["mamba"] * 3
+ + ["hybrid"]
+ + ["mamba"] * 2
+ )
+ else:
+ self.layers_block_type = layers_block_type
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.num_logits_to_keep = num_logits_to_keep
+
+
+__all__ = ["Zamba2Config"]
diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py
new file mode 100644
index 00000000000000..caca525f879505
--- /dev/null
+++ b/src/transformers/models/zamba2/modeling_zamba2.py
@@ -0,0 +1,2229 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/zamba2/modular_zamba2.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_zamba2.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from itertools import cycle
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, DynamicCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+ replace_return_docstrings,
+)
+from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available
+from .configuration_zamba2 import Zamba2Config
+
+
+if is_mamba_ssm_available():
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+else:
+ selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None
+
+if is_causal_conv1d_available():
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+else:
+ causal_conv1d_update, causal_conv1d_fn = None, None
+
+
+logger = logging.get_logger(__name__)
+
+_CONFIG_FOR_DOC = "Zamba2Config"
+
+
+class Zamba2RMSNormGated(torch.nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states, gate=None):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+
+ if gate is not None:
+ hidden_states = hidden_states * nn.functional.silu(gate.to(torch.float32))
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+
+ return self.weight * hidden_states.to(input_dtype)
+
+
+class Zamba2RMSNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-6):
+ """
+ Zamba2RMSNorm is equivalent to T5LayerNorm
+ """
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, hidden_states):
+ input_dtype = hidden_states.dtype
+ hidden_states = hidden_states.to(torch.float32)
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+ return self.weight * hidden_states.to(input_dtype)
+
+ def extra_repr(self):
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class Zamba2HybridDynamicCache(DynamicCache):
+ """
+ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
+ (which has a constant shape regardless of seq_len).
+
+ This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
+ and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
+ For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
+ while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
+ For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
+ while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
+ and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
+ """
+
+ def __init__(
+ self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
+ ):
+ self.dtype = dtype
+ self.layers_block_type = config.layers_block_type
+ self.has_previous_state = False
+ self.intermediate_size = int(config.mamba_expand * config.hidden_size)
+ self.ssm_state_size = config.mamba_d_state
+ self.conv_kernel_size = config.mamba_d_conv
+ self.n_mamba_heads = config.n_mamba_heads
+ self.transformer_layers = []
+ self._modules = {}
+ self._parameters = {}
+ self._buffers = {}
+ self.conv_states = {
+ i: torch.zeros(
+ batch_size,
+ self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state,
+ self.conv_kernel_size,
+ device=device,
+ dtype=dtype,
+ )
+ for i in range(config.num_hidden_layers)
+ }
+ self.ssm_states = {
+ i: torch.zeros(
+ batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype
+ )
+ for i in range(config.num_hidden_layers)
+ }
+ for i in range(config.num_hidden_layers):
+ if self.layers_block_type[i] == "hybrid":
+ self.transformer_layers.append(i)
+ self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
+ self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
+
+ def update(
+ self,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ layer_idx: int,
+ cache_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Update the cache
+ if self.key_cache[layer_idx].shape[-1] == 0:
+ self.key_cache[layer_idx] = key_states
+ self.value_cache[layer_idx] = value_states
+ else:
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
+
+ return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+ def reorder_cache(self, beam_idx: torch.LongTensor):
+ """Reorders the cache for beam search, given the selected beam indices."""
+ for layer_idx in range(len(self.key_cache)):
+ device = self.key_cache[layer_idx].device
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+ device = self.value_cache[layer_idx].device
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+ device = self.conv_states[layer_idx].device
+ self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
+ device = self.ssm_states[layer_idx].device
+ self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))
+
+ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+ """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+ # take any layer that contains cache and not empty tensor
+ layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
+ if len(self.key_cache) <= layer_idx:
+ return 0
+ return self.key_cache[layer_idx].shape[-2]
+
+ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+ raise NotImplementedError("Zamba2HybridDynamicCache does not have a legacy cache equivalent.")
+
+ @classmethod
+ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
+ raise NotImplementedError("Zamba2HybridDynamicCache does not have a legacy cache equivalent.")
+
+ def update_conv_state(
+ self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
+ ) -> torch.Tensor:
+ conv_state = self.conv_states[layer_idx]
+ cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
+
+ conv_state = conv_state.roll(shifts=-1, dims=-1)
+ conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
+ self.conv_states[layer_idx].zero_()
+ self.conv_states[layer_idx] += conv_state
+ return self.conv_states[layer_idx]
+
+ def reset(self):
+ self.conv_states.zero_()
+ self.ssm_states.zero_()
+
+
+class Zamba2RotaryEmbedding(nn.Module):
+ def __init__(
+ self,
+ config: Zamba2Config,
+ device=None,
+ ):
+ super().__init__()
+ self.rope_kwargs = {"base": config.rope_theta, "dim": config.attention_head_dim}
+ # BC: "rope_type" was originally "type"
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+ else:
+ self.rope_type = "default"
+ self.max_seq_len_cached = config.max_position_embeddings
+ self.original_max_seq_len = config.max_position_embeddings
+
+ self.config = config
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+ inv_freq, self.attention_scaling = self.rope_init_fn(config=None, device=device, **self.rope_kwargs)
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+ self.original_inv_freq = self.inv_freq
+
+ def _dynamic_frequency_update(self, position_ids, device):
+ """
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
+ 1 - growing beyond the cached sequence length (allow scaling)
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+ """
+ seq_len = torch.max(position_ids) + 1
+ if seq_len > self.max_seq_len_cached: # growth
+ inv_freq, self.attention_scaling = self.rope_init_fn(
+ self.config, device, seq_len=seq_len, **self.rope_kwargs
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
+ self.max_seq_len_cached = seq_len
+
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+ self.max_seq_len_cached = self.original_max_seq_len
+
+ @torch.no_grad()
+ def forward(self, x, position_ids):
+ if "dynamic" in self.rope_type:
+ self._dynamic_frequency_update(position_ids, device=x.device)
+
+ # Core RoPE block
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+ position_ids_expanded = position_ids[:, None, :].float()
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+ device_type = x.device.type
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+ with torch.autocast(device_type=device_type, enabled=False):
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos = emb.cos()
+ sin = emb.sin()
+
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+ cos = cos * self.attention_scaling
+ sin = sin * self.attention_scaling
+
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+ """
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+ if n_rep == 1:
+ return hidden_states
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+def layer_type_list(config: Zamba2Config):
+ """
+ Returns list of layer ids containing hybrid layers
+ """
+ output_list = []
+ for index, type in enumerate(config.layers_block_type):
+ if type == "hybrid":
+ output_list.append(index)
+ return output_list
+
+
+class Zamba2Attention(nn.Module):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
+ and "Generating Long Sequences with Sparse Transformers".
+
+ Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
+ The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
+ The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
+ (see fig. 2 in https://arxiv.org/pdf/2405.16712).
+ Additionally, replaced
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
+
+ Multi-headed attention from 'Attention Is All You Need' paper.
+
+ Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
+ The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
+ The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
+ (see fig. 2 in https://arxiv.org/pdf/2405.16712).
+ Additionally, replaced
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
+ Finally, this attention layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this
+ layer is tied, un-tied adapters (formally the same as LoRA but used in the base model) modules are added to the q, k, v projectors to increase
+ expressivity with a small memory overhead (see Fig. 2 of https://arxiv.org/pdf/2411.15242).
+ """
+
+ def __init__(
+ self,
+ config: Zamba2Config,
+ layer_idx: Optional[int] = None,
+ num_fwd_mem_blocks: int = None,
+ block_id: int = None,
+ ):
+ super().__init__()
+ self.config = config
+ self.layer_idx = layer_idx
+
+ self.hidden_size = config.hidden_size
+ self.attention_hidden_size = config.attention_hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = config.attention_head_dim
+ self.num_key_value_heads = config.num_key_value_heads
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.is_causal = True
+ self.attention_dropout = config.attention_dropout
+
+ if (self.head_dim * self.num_heads) != self.attention_hidden_size:
+ raise ValueError(
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+ f" and `num_heads`: {self.num_heads})."
+ )
+ self.q_proj = nn.Linear(self.attention_hidden_size, self.num_heads * self.head_dim, bias=False)
+ self.k_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.v_proj = nn.Linear(self.attention_hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+ self.num_fwd_mem_blocks = num_fwd_mem_blocks
+ self.layer_block_map = layer_type_list(config)
+ self.block_id = block_id
+
+ if config.use_shared_attention_adapter:
+ self.linear_q_adapter_list = nn.ModuleList([])
+ self.linear_k_adapter_list = nn.ModuleList([])
+ self.linear_v_adapter_list = nn.ModuleList([])
+
+ for i in range(self.num_fwd_mem_blocks):
+ if i % config.num_mem_blocks == block_id:
+ linear_q_adapter = nn.Sequential(
+ nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
+ nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
+ )
+ linear_k_adapter = nn.Sequential(
+ nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
+ nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
+ )
+ linear_v_adapter = nn.Sequential(
+ nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
+ nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
+ )
+ else:
+ linear_q_adapter = nn.Identity()
+ linear_k_adapter = nn.Identity()
+ linear_v_adapter = nn.Identity()
+ self.linear_q_adapter_list.append(linear_q_adapter)
+ self.linear_k_adapter_list.append(linear_k_adapter)
+ self.linear_v_adapter_list.append(linear_v_adapter)
+
+ self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)}
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ layer_idx: int,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Zamba2HybridDynamicCache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+ if self.config.use_shared_attention_adapter:
+ adapter_layer_idx = self.layer_dic[layer_idx]
+ query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states)
+ key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states)
+ value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if self.config.use_mem_rope:
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim / 2)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+def rotate_half(x):
+ """Rotates half the hidden dims of the input."""
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+ """Applies Rotary Position Embedding to the query and key tensors.
+
+ Args:
+ q (`torch.Tensor`): The query tensor.
+ k (`torch.Tensor`): The key tensor.
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
+ position_ids (`torch.Tensor`, *optional*):
+ Deprecated and unused.
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+ Returns:
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """
+ cos = cos.unsqueeze(unsqueeze_dim)
+ sin = sin.unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
+# Added softmax_scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of self._flash_attention_forward
+# dropped use_sliding_windows from the arguments of self._flash_attention_forward
+class Zamba2FlashAttention2(Zamba2Attention):
+ """
+ Zamba2 flash attention module. This module inherits from `Zamba2Attention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ layer_idx: int,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Zamba2HybridDynamicCache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+ if self.config.use_shared_attention_adapter:
+ adapter_layer_idx = self.layer_dic[layer_idx]
+ query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states)
+ key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states)
+ value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if self.config.use_mem_rope:
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reashape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ softmax_scale = 1 / math.sqrt(self.head_dim / 2)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate,
+ softmax_scale=softmax_scale,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
+# added scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of torch.nn.functional.scaled_dot_product_attention
+class Zamba2SdpaAttention(Zamba2Attention):
+ """
+ Zamba2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Zamba2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ layer_idx: int,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Zamba2HybridDynamicCache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Zamba2Model is using Zamba2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+ if self.config.use_shared_attention_adapter:
+ adapter_layer_idx = self.layer_dic[layer_idx]
+ query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states)
+ key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states)
+ value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if self.config.use_mem_rope:
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ softmax_scale = 1 / math.sqrt(self.head_dim / 2)
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ scale=softmax_scale,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.attention_hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+# Helper methods for segment sum computation
+
+
+def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
+ """
+ Padding x tensor with `pad_size` on the seq_len dim (dim=1)
+
+ Assumes that we only have tensors of either size 4 or 3
+ """
+ pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(input_tensor.shape) == 4 else (0, 0, 0, pad_size, 0, 0)
+
+ return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
+
+
+def reshape_into_chunks(input_tensor, pad_size, chunk_size):
+ """
+ Padding input_tensor with `pad_size` on the seq_len dim (dim=1) and
+ simultaneously splitting it into chunk sequences.
+
+ Assumes that we only have tensors of either size 4 or 3
+ """
+ # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
+ input_tensor = pad_tensor_by_size(input_tensor, pad_size)
+
+ if len(input_tensor.shape) == 3:
+ # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
+ return input_tensor.reshape(input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2])
+ else:
+ # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size] -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
+ return input_tensor.reshape(
+ input_tensor.shape[0], -1, chunk_size, input_tensor.shape[2], input_tensor.shape[3]
+ )
+
+
+def segment_sum(input_tensor):
+ """
+ More stable segment sum calculation. Uses cumulative sums and masking instead of direct subtractions.
+ """
+ chunk_size = input_tensor.size(-1)
+ # 1. expand input tensor to have an additional dimension and repeat along that dimension
+ # [..., chunk_size] -> [..., chunk_size, chunk_size]
+ input_tensor = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)
+ # 2. create a lower triangular mask with the diagonal set to 0 to 0 out elements above diag
+ mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=-1)
+ input_tensor = input_tensor.masked_fill(~mask, 0)
+ # 3. compute actual cumsum
+ tensor_segsum = torch.cumsum(input_tensor, dim=-2)
+
+ # 4. apply mask to keep only the lower triangular part of the cumulative sum result (incl diagonal this time)
+ mask = torch.tril(torch.ones(chunk_size, chunk_size, device=input_tensor.device, dtype=torch.bool), diagonal=0)
+ tensor_segsum = tensor_segsum.masked_fill(~mask, -torch.inf)
+ return tensor_segsum
+
+
+is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+
+
+class Zamba2MambaMixer(nn.Module):
+ """
+ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
+ ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
+ and is why Mamba is called **selective** state spaces)
+ """
+
+ def __init__(self, config: Zamba2Config, layer_idx: int = None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.ssm_state_size = config.mamba_d_state
+ self.conv_kernel_size = config.mamba_d_conv
+ self.intermediate_size = int(config.mamba_expand * self.hidden_size)
+ self.layer_idx = layer_idx
+ self.use_conv_bias = config.use_conv_bias
+ self.activation = "silu"
+ self.act = nn.SiLU()
+
+ self.n_groups = config.mamba_ngroups
+ self.head_dim = config.mamba_headdim
+ self.num_heads = self.config.n_mamba_heads
+ self.chunk_size = config.chunk_size
+
+ self.time_step_limit = config.time_step_limit
+ self.time_step_min = config.time_step_min
+ self.time_step_max = config.time_step_max
+
+ self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
+ self.conv1d = nn.Conv1d(
+ in_channels=self.conv_dim,
+ out_channels=self.conv_dim,
+ bias=True,
+ kernel_size=config.mamba_d_conv,
+ groups=self.conv_dim,
+ padding=config.mamba_d_conv - 1,
+ )
+
+ # projection of the input hidden states
+ projection_size = self.intermediate_size + self.conv_dim + self.num_heads
+ self.in_proj = nn.Linear(
+ self.hidden_size,
+ projection_size,
+ bias=config.add_bias_linear,
+ )
+ # selective projection used to make dt, B and C input dependant
+
+ # time step projection (discretization)
+ # instantiate once and copy inv_dt in init_weights of PretrainedModel
+ self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
+
+ # S4D real initialization. These are not discretized!
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+ A = torch.arange(1, self.num_heads + 1)
+ self.A_log = nn.Parameter(torch.log(A))
+ self.A_log._no_weight_decay = True
+ self.norm = Zamba2RMSNormGated(self.intermediate_size, eps=1e-5)
+ self.D = nn.Parameter(torch.ones(self.num_heads))
+ self.D._no_weight_decay = True
+
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear)
+
+ if not is_fast_path_available:
+ logger.warning_once(
+ "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
+ " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
+ " https://github.com/Dao-AILab/causal-conv1d"
+ )
+
+ def cuda_kernels_forward(
+ self,
+ hidden_states: torch.Tensor,
+ cache_params: Optional[Zamba2HybridDynamicCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ # set up dimensions for reshapes later
+
+ batch_size, seq_len, _ = hidden_states.shape
+ groups_time_state_size = self.n_groups * self.ssm_state_size
+ d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads
+
+ # getting projected states from cache if it exists
+ if cache_params is not None and cache_params.has_previous_state:
+ in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
+ d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
+ split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads]
+ _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1)
+
+ hidden_states_B_C = causal_conv1d_update(
+ hidden_states_B_C,
+ cache_params.conv_states[self.layer_idx],
+ self.conv1d.weight.squeeze(1),
+ self.conv1d.bias,
+ self.activation,
+ )
+
+ hidden_states, B, C = torch.split(
+ hidden_states_B_C,
+ [self.intermediate_size, groups_time_state_size, groups_time_state_size],
+ dim=-1,
+ )
+ A = -torch.exp(self.A_log.float()) # (nheads,)
+
+ A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
+ dt = dt[:, :, None].expand(-1, -1, self.head_dim)
+ dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
+ D = self.D[:, None, ...].expand(-1, self.head_dim)
+ B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
+ C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
+ hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
+ hidden_states = selective_state_update(
+ cache_params.ssm_states[self.layer_idx],
+ hidden_states_reshaped,
+ dt,
+ A,
+ B,
+ C,
+ D,
+ z=None,
+ dt_bias=dt_bias,
+ dt_softplus=True,
+ )
+ hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
+ hidden_states = self.norm(hidden_states, gate)
+ out = self.out_proj(hidden_states)[:, None, ...]
+ # if no cache is found, calling the kernel
+ else:
+ if attention_mask is not None and not torch.all(attention_mask == 1):
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
+ dtype = hidden_states.dtype
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+ # 1. Gated MLP's linear projection
+ projected_states = self.in_proj(hidden_states)
+ A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
+ dt_limit_kwargs = {} if self.time_step_limit is None else {"dt_limit": self.time_step_limit}
+ if attention_mask is not None:
+ input_not_masked = torch.all(attention_mask == 1)
+ else:
+ input_not_masked = True
+
+ if self.training and cache_params is None and input_not_masked:
+ out, ssm_state = mamba_split_conv1d_scan_combined(
+ projected_states,
+ self.conv1d.weight.squeeze(1),
+ self.conv1d.bias,
+ self.dt_bias,
+ A,
+ D=self.D,
+ chunk_size=self.chunk_size,
+ seq_idx=None,
+ activation=self.activation,
+ rmsnorm_weight=self.norm.weight,
+ rmsnorm_eps=self.norm.variance_epsilon,
+ outproj_weight=self.out_proj.weight,
+ outproj_bias=self.out_proj.bias,
+ headdim=self.head_dim,
+ ngroups=self.n_groups,
+ norm_before_gate=False,
+ return_final_states=True,
+ **dt_limit_kwargs,
+ )
+
+ else:
+ gate, hidden_states_B_C, time_step = torch.split(
+ projected_states,
+ [self.intermediate_size, self.conv_dim, self.num_heads],
+ dim=-1,
+ )
+
+ # 1D Convolution
+ if cache_params is not None:
+ hidden_states_B_C_t = hidden_states_B_C.transpose(1, 2)
+ conv_state = nn.functional.pad(
+ hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0)
+ )
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
+ hidden_states_B_C = self.act(
+ self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len]
+ ) # (B, L, self.d_inner + 2 * ngroups * d_state)
+ else:
+ hidden_states_B_C = causal_conv1d_fn(
+ x=hidden_states_B_C.transpose(1, 2),
+ weight=self.conv1d.weight.squeeze(1),
+ bias=self.conv1d.bias,
+ activation=self.activation,
+ ).transpose(1, 2)[:, :seq_len]
+ hidden_states, B, C = torch.split(
+ hidden_states_B_C,
+ [self.intermediate_size, groups_time_state_size, groups_time_state_size],
+ dim=-1,
+ )
+ if attention_mask is not None and not torch.all(attention_mask == 1):
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
+ dtype = hidden_states.dtype
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+ scan_output, ssm_state = mamba_chunk_scan_combined(
+ hidden_states.view(batch_size, seq_len, -1, self.head_dim),
+ time_step,
+ A,
+ B.view(batch_size, seq_len, self.n_groups, -1),
+ C.view(batch_size, seq_len, self.n_groups, -1),
+ chunk_size=self.chunk_size,
+ D=self.D,
+ z=None,
+ seq_idx=None,
+ return_final_states=True,
+ dt_bias=self.dt_bias,
+ dt_softplus=True,
+ **dt_limit_kwargs,
+ )
+ if ssm_state is not None and cache_params is not None:
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+ scan_output = scan_output.view(batch_size, seq_len, -1)
+ # Multiply "gate" branch and apply extra normalization layer
+ scan_output = self.norm(scan_output, gate)
+ out = self.out_proj(scan_output)
+ return out
+
+ # fmt: off
+ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
+ batch_size, seq_len, _ = input_states.shape
+ dtype = input_states.dtype
+ # Gated MLP's linear projection
+ if cache_params is not None and cache_params.has_previous_state:
+ projected_states = self.in_proj(input_states.squeeze(1))
+ else:
+ if attention_mask is not None and not torch.all(attention_mask==1):
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
+ input_states = (input_states * attention_mask[:, :, None]).to(dtype)
+ projected_states = self.in_proj(input_states)
+ d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2
+ _, _, gate, hidden_states, dt = projected_states.split(
+ [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
+ )
+
+ # Convolution sequence transformation
+ if cache_params is not None:
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+ ssm_state = ssm_state.to(hidden_states.device)
+ if cache_params.has_previous_state:
+ gate = gate.unsqueeze(1)
+ conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
+ conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
+ # handle batched generation - states are copied through
+ conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
+ hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1)
+ if self.use_conv_bias:
+ hidden_states += self.conv1d.bias
+ hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding
+ else:
+ hidden_states = hidden_states.transpose(1,2)
+ conv_state = nn.functional.pad(
+ hidden_states,
+ (self.conv_kernel_size - hidden_states.shape[-1], 0)
+ )
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
+ hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len]
+ if attention_mask is not None and not torch.all(attention_mask==1):
+ dtype = hidden_states.dtype
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+ else:
+ ssm_state = torch.zeros(
+ (batch_size, self.num_heads, self.head_dim, self.ssm_state_size),
+ device=hidden_states.device, dtype=dtype
+ )
+ hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2))
+ hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1)
+ A = -torch.exp(self.A_log.float()) # [num_heads]
+ if cache_params is not None and cache_params.has_previous_state:
+ # Note: there is no need to pad parameter matrices here, as there is just one new token
+ # for batched generation
+ dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
+ # [num_heads] -> [num_heads, head_dim]
+ dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
+
+ dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
+ dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max)
+ A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
+ # [bsz, num_heads, head_dim, state_size]
+ dA = torch.exp(dt[..., None] * A)
+
+ # Discretize B
+ # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
+ # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
+ B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
+ B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
+ B = B.reshape(batch_size, -1, B.shape[-1])
+ # [bsz, num_heads, head_dim, state_size]
+ dB = dt[..., None] * B[..., None, :]
+
+ # Discretize x into dB
+ # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
+ hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
+ dBx = dB * hidden_states[..., None]
+
+ # State calculation
+ cache_params.ssm_states[self.layer_idx].copy_(
+ cache_params.ssm_states[self.layer_idx] * dA + dBx
+ )
+
+ # Subsequent output
+ # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
+ C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
+ C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
+ C = C.reshape(batch_size, -1, C.shape[-1])
+ # [bsz, num_heads, head_dim]
+
+ ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n]
+ # Reshape ssm_states to merge the first two dimensions
+ ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
+ C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
+ y = torch.bmm(ssm_states_reshaped, C_reshaped)
+ y = y.view(batch_size, self.num_heads, self.head_dim)
+
+ # D skip connection
+ # [num_heads] -> [num_heads, head_dim]
+ D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
+ y = (y + hidden_states * D).to(y.dtype)
+
+ # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
+ y = y.reshape(batch_size, -1)[:, None, ...]
+ else:
+ # begin ssd naive implementation without einsums
+ dt = nn.functional.softplus(dt + self.dt_bias)
+ dt = torch.clamp(dt, self.time_step_min)
+ hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
+ B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
+ C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
+ B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
+ C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
+ pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
+
+ D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
+
+ # Discretize x and A
+ hidden_states = hidden_states * dt[..., None]
+ A = A.to(hidden_states.dtype) * dt
+
+ # Rearrange into blocks/chunks
+ hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
+
+
+ # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
+ A = A.permute(0, 3, 1, 2)
+ A_cumsum = torch.cumsum(A, dim=-1)
+
+ # 1. Compute the output for each intra-chunk (diagonal blocks)
+ # This is the analog of a causal mask
+ L = torch.exp(segment_sum(A))
+
+ # First, contraction of C and B to get G (attention-weights like)
+ G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n)
+ G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
+
+
+ # Step 2: Compute M, equivalent to applying attention mask to weights
+ M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
+ M = M_intermediate.sum(dim=-1)
+
+ # Step 3: Compute Y_diag (apply to values)
+ Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
+
+ # (right term of low-rank factorization of off-diagonal blocks; B terms)
+
+ decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
+ B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
+ # permute back B * decay states
+ states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
+ if cache_params is not None and cache_params.has_previous_state:
+ previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...]
+ else:
+ previous_states = torch.zeros_like(states[:, :1])
+ states = torch.cat([previous_states, states], dim=1)
+ decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
+
+ states_permuted = states.permute(0, 2, 1, 3, 4)
+ result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
+ new_states = result.permute(0, 2, 1, 3, 4)
+ states, ssm_state = new_states[:, :-1], new_states[:, -1]
+
+ # Compute state -> output conversion per chunk
+ # (left term of low-rank factorization of off-diagonal blocks; C terms)
+ state_decay_out = torch.exp(A_cumsum)
+ # compute Yoff
+ C_times_states = (C[..., None, :] * states[:, :, None, ...])
+ state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
+ Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
+ # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
+
+ y = Y_diag + Y_off
+ # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
+ y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
+
+ y = y + D_residual
+ # Cutting off padded chunks
+ if pad_size > 0:
+ y = y[:, :seq_len, :, :]
+ y = y.reshape(batch_size, seq_len, -1)
+ if ssm_state is not None and cache_params is not None:
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+
+ scan_output = self.norm(y, gate)
+
+ # end ssd naive
+
+ # 4. Final linear projection
+ contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
+ return contextualized_states
+ # fmt: on
+
+ def forward(
+ self,
+ hidden_states,
+ cache_params: Optional[Zamba2HybridDynamicCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
+ return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
+
+ return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
+
+
+class Zamba2MLP(nn.Module):
+ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int = None):
+ """
+ This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer
+ is tied, un-tied adapter modules (formally same as LoRA, but used in the base model) are added to the up and gate projectors to increase expressivity with a small memory overhead.
+ """
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+ self.act_fn = ACT2FN[config.hidden_act]
+ self.num_fwd_mem_blocks = num_fwd_mem_blocks
+ self.block_id = block_id
+
+ def gated_act_fn(x):
+ x = torch.chunk(x, 2, dim=-1)
+ return self.act_fn(x[0]) * x[1]
+
+ self.gated_act_fn = gated_act_fn
+ self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=config.add_bias_linear)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear)
+
+ if self.config.use_shared_mlp_adapter:
+ self.gate_up_proj_adapter_list = nn.ModuleList([])
+ for i in range(self.num_fwd_mem_blocks):
+ if i % config.num_mem_blocks == block_id:
+ gate_up_proj_adapter = nn.Sequential(
+ nn.Linear(self.config.hidden_size, self.config.adapter_rank, bias=False),
+ nn.Linear(self.config.adapter_rank, 2 * self.intermediate_size, bias=False),
+ )
+ else:
+ gate_up_proj_adapter = nn.Identity()
+ self.gate_up_proj_adapter_list.append(gate_up_proj_adapter)
+
+ layer_block_map = layer_type_list(config)
+ self.layer_dic = {value: index for index, value in enumerate(layer_block_map)}
+
+ def forward(self, hidden_state, layer_idx=None):
+ gate_up_state = self.gate_up_proj(hidden_state)
+ if self.config.use_shared_mlp_adapter:
+ layer_idx = self.layer_dic[layer_idx]
+ gate_up_state = gate_up_state + self.gate_up_proj_adapter_list[layer_idx](hidden_state)
+
+ hidden_state = self.gated_act_fn(gate_up_state)
+ output = self.down_proj(hidden_state)
+ return output
+
+
+def count_mem_blocks_in_config(config: Zamba2Config):
+ """
+ Count number of shared blocks
+ """
+ num_gs = 0
+ for val in config.layers_block_type:
+ if val == "hybrid":
+ num_gs += 1
+ return num_gs
+
+
+ZAMBA2_ATTENTION_CLASSES = {
+ "eager": Zamba2Attention,
+ "flash_attention_2": Zamba2FlashAttention2,
+ "sdpa": Zamba2SdpaAttention,
+}
+
+
+class Zamba2AttentionDecoderLayer(nn.Module):
+ def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.block_id = block_id
+ num_gs = count_mem_blocks_in_config(config)
+ self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation](
+ config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id
+ )
+ self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id)
+ self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps)
+ self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ original_hidden_states: torch.Tensor,
+ layer_idx: int,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Zamba2HybridDynamicCache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)`
+ original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`.
+ This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The
+ concatenated tensor is then used as input of the pre-attention RMSNorm
+ (see fig. 2 in https://arxiv.org/pdf/2405.16712).
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ """
+ hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1)
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ layer_idx=layer_idx,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.pre_ff_layernorm(hidden_states)
+ hidden_states = self.feed_forward(hidden_states, layer_idx)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class Zamba2MambaDecoderLayer(nn.Module):
+ def __init__(self, config: Zamba2Config, layer_idx: int):
+ super().__init__()
+ self.mamba = Zamba2MambaMixer(config=config, layer_idx=layer_idx)
+ self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ self.layer_idx = layer_idx
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ original_hidden_states: Optional[torch.Tensor] = None,
+ layer_idx: int = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Zamba2HybridDynamicCache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[torch.LongTensor] = None,
+ transformer_hidden_states: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ transformer_hidden_states (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Output of the previous shared transformer layer (if present) of shape `(batch_size, seq_len, embed_dim)`.
+ """
+
+ residual = hidden_states
+
+ # `transformer_hidden_states` is the output from shared transformer + linear layer (see fig. 2 in https://arxiv.org/pdf/2405.16712).
+ # `transformer_hidden_states` is then added to the input to the mamba layer below (as described in eq. (6) of https://arxiv.org/pdf/2405.16712).
+ hidden_states = (
+ hidden_states + transformer_hidden_states if transformer_hidden_states is not None else hidden_states
+ )
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states = self.mamba(
+ hidden_states=hidden_states,
+ cache_params=past_key_value,
+ attention_mask=attention_mask,
+ )
+
+ self_attn_weights = None
+
+ # residual connection after mamba
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (past_key_value,)
+
+ return outputs
+
+
+class Zamba2HybridLayer(nn.Module):
+ def __init__(
+ self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer
+ ):
+ super().__init__()
+ self.linear = linear
+ self.mamba_decoder = mamba
+ self.shared_transformer = shared_transformer
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ original_hidden_states: Optional[torch.Tensor] = None,
+ layer_idx: int = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Zamba2HybridDynamicCache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ original_hidden_states (`torch.FloatTensor`): word embedding output that will be concatenated with
+ hidden activations to form the input of the shared transformer layer.
+ layer_idx (`int`): layer number.
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ """
+
+ layer_outputs = self.shared_transformer(
+ hidden_states,
+ original_hidden_states=original_hidden_states,
+ layer_idx=layer_idx,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ transformer_hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ self_attn_weights = layer_outputs[1]
+
+ transformer_hidden_states = self.linear(transformer_hidden_states)
+
+ layer_outputs = self.mamba_decoder(
+ hidden_states,
+ transformer_hidden_states=transformer_hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ if output_attentions:
+ layer_outputs = (layer_outputs[0], self_attn_weights) + layer_outputs[2:]
+
+ return layer_outputs
+
+
+ZAMBA2_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`Zamba2Config`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.",
+ ZAMBA2_START_DOCSTRING,
+)
+class Zamba2PreTrainedModel(PreTrainedModel):
+ config_class = Zamba2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = False
+ _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache
+ _is_stateful = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, Zamba2MambaMixer):
+ module.A_log._no_weight_decay = True
+ module.D._no_weight_decay = True
+
+ dt = torch.exp(
+ torch.rand(self.config.n_mamba_heads)
+ * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
+ + math.log(self.config.time_step_min)
+ ).clamp(min=self.config.time_step_floor)
+ # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
+
+ with torch.no_grad():
+ module.dt_bias.copy_(inv_dt)
+ module.dt_bias._no_reinit = True
+
+
+ZAMBA2_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Zamba2HybridDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ A Zamba2HybridDynamicCache object containing pre-computed hidden-states (keys and values in the
+ self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
+ Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
+ `(batch_size, d_inner, d_state)` respectively.
+ See the `Zamba2HybridDynamicCache` class for more details.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+ "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.",
+ ZAMBA2_START_DOCSTRING,
+)
+class Zamba2Model(Zamba2PreTrainedModel):
+ """
+ Model consisting of *config.num_hidden_layers* layers.
+
+ Args:
+ config: Zamba2Config
+ """
+
+ def __init__(self, config: Zamba2Config):
+ super().__init__(config)
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ blocks = [Zamba2AttentionDecoderLayer(config, block_id=k) for k in range(config.num_mem_blocks)]
+ mamba_layers = []
+ linear_layers = []
+ self.layers_block_type = config.layers_block_type
+ for i in range(config.num_hidden_layers):
+ if config.layers_block_type[i] == "mamba":
+ mamba_layers.append(Zamba2MambaDecoderLayer(config, layer_idx=i))
+ elif config.layers_block_type[i] == "hybrid":
+ linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False))
+ mamba_layers.append(Zamba2MambaDecoderLayer(config, layer_idx=i))
+ mamba_layers = iter(mamba_layers)
+ linear_layers = iter(linear_layers)
+ blocks = cycle(blocks)
+ layers = self.get_layers(blocks, linear_layers, mamba_layers)
+ self.layers = nn.ModuleList(layers)
+
+ self._attn_implementation = config._attn_implementation
+ self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ if config.use_mem_rope:
+ if config.use_long_context:
+ logger.warning_once(
+ "`use_long_context` set to `True`: using rescaled `rope_theta` and extended `max_position_embeddings`."
+ )
+ self.rotary_emb = Zamba2RotaryEmbedding(config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Zamba2HybridDynamicCache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ hidden_states = inputs_embeds
+
+ original_hidden_states = torch.clone(inputs_embeds)
+ # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer
+
+ if use_cache and past_key_values is None:
+ logger.warning_once(
+ "Zamba2 requires an initialized `Zamba2HybridDynamicCache` to return a cache. None was "
+ "provided, so no cache will be returned."
+ )
+
+ if cache_position is None:
+ cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+
+ # create position embeddings to be shared across the decoder layers
+ if self.config.use_mem_rope:
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ else:
+ position_embeddings = None
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for layer_idx, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer.__call__,
+ hidden_states,
+ original_hidden_states,
+ layer_idx,
+ attention_mask,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = layer(
+ hidden_states,
+ original_hidden_states=original_hidden_states,
+ layer_idx=layer_idx,
+ attention_mask=attention_mask,
+ causal_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ if layer_outputs[1] is not None:
+ # append attentions only of attention layers. Mamba layers return `None` as the attention weights
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.final_layernorm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if past_key_values and not past_key_values.has_previous_state:
+ past_key_values.has_previous_state = True
+
+ next_cache = None if not use_cache else past_key_values
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and 0.0 in attention_mask:
+ return attention_mask
+ return None
+
+ dtype, device = input_tensor.dtype, input_tensor.device
+ min_dtype = torch.finfo(dtype).min
+ sequence_length = input_tensor.shape[1]
+ target_length = cache_position[-1] + 1
+
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ if attention_mask.dim() == 2:
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
+ causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type == "cuda"
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ def get_layers(self, blocks, linear_layers, mamba_layers):
+ layers = []
+ self._tied_weights_keys = []
+ for layer_id, layer_type in enumerate(self.layers_block_type):
+ if layer_type == "hybrid":
+ block = next(blocks)
+ if self.config.num_mem_blocks * len(layer_type_list(self.config)) > 1:
+ prefix_name = f"layers.{layer_id}."
+ tied_keys = [
+ "shared_transformer.self_attn.q_proj.weight",
+ "shared_transformer.self_attn.k_proj.weight",
+ "shared_transformer.self_attn.v_proj.weight",
+ "shared_transformer.self_attn.o_proj.weight",
+ "shared_transformer.feed_forward.gate_up_proj.weight",
+ "shared_transformer.feed_forward.down_proj.weight",
+ "shared_transformer.input_layernorm.weight",
+ "shared_transformer.pre_ff_layernorm.weight",
+ ]
+ self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
+ if self.config.use_shared_mlp_adapter:
+ tied_keys_adapter = []
+ adapter_id = 0
+ for _layer_type in self.layers_block_type:
+ if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id:
+ tied_keys_adapter.append(
+ "shared_transformer.feed_forward.gate_up_proj_adapter_list."
+ + str(adapter_id)
+ + ".0.weight"
+ )
+ tied_keys_adapter.append(
+ "shared_transformer.feed_forward.gate_up_proj_adapter_list."
+ + str(adapter_id)
+ + ".1.weight"
+ )
+ adapter_id += 1
+ self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter]
+ if self.config.use_shared_attention_adapter:
+ tied_keys_adapter = []
+ adapter_id = 0
+ for _layer_type in self.layers_block_type:
+ if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id:
+ tied_keys_adapter.append(
+ "shared_transformer.self_attn.linear_q_adapter_list."
+ + str(adapter_id)
+ + ".0.weight"
+ )
+ tied_keys_adapter.append(
+ "shared_transformer.self_attn.linear_k_adapter_list."
+ + str(adapter_id)
+ + ".0.weight"
+ )
+ tied_keys_adapter.append(
+ "shared_transformer.self_attn.linear_v_adapter_list."
+ + str(adapter_id)
+ + ".0.weight"
+ )
+ tied_keys_adapter.append(
+ "shared_transformer.self_attn.linear_q_adapter_list."
+ + str(adapter_id)
+ + ".1.weight"
+ )
+ tied_keys_adapter.append(
+ "shared_transformer.self_attn.linear_k_adapter_list."
+ + str(adapter_id)
+ + ".1.weight"
+ )
+ tied_keys_adapter.append(
+ "shared_transformer.self_attn.linear_v_adapter_list."
+ + str(adapter_id)
+ + ".1.weight"
+ )
+ adapter_id += 1
+ self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter]
+ layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers)))
+ else:
+ layers.append(next(mamba_layers))
+ return layers
+
+
+# Adapted from transformers.models.jamba.modeling_jamba.JambaForCausalLM with Jamba->Zamba2, JAMBA->ZAMBA2
+class Zamba2ForCausalLM(Zamba2PreTrainedModel, GenerationMixin):
+ def __init__(self, config: Zamba2Config):
+ super().__init__(config)
+ self.model = Zamba2Model(config)
+ self._tied_weights_keys = ["lm_head.weight", *self.model._tied_weights_keys]
+ self.vocab_size = config.vocab_size
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING)
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Zamba2HybridDynamicCache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ num_logits_to_keep: int = 0,
+ **loss_kwargs,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ num_logits_to_keep (`int` or `None`, *optional*):
+ Calculate logits for the last `num_logits_to_keep` tokens. If `None`, calculate logits for all
+ `input_ids`. Only last token logits are needed for generation, and calculating them only for that token
+ can save memory, which becomes pretty significant for long sequences.
+
+ Returns:
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, Zamba2ForCausalLM
+
+ >>> model = Zamba2ForCausalLM.from_pretrained("Zyphra/Zamba2-7B-v1")
+ >>> tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B-v1")
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0]
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ position_ids=None,
+ use_cache=True,
+ **kwargs,
+ ):
+ # Overwitten -- has a unique cache type, `Zamba2HybridDynamicCache`
+
+ empty_past_kv = past_key_values is None
+
+ # Omit tokens covered by past_key_values
+ if not empty_past_kv:
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
+ if inputs_embeds is not None: # Exception 1
+ input_ids = input_ids[:, -cache_position.shape[0] :]
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
+ input_ids = input_ids[:, cache_position]
+ else:
+ past_key_values = Zamba2HybridDynamicCache(
+ self.config, input_ids.shape[0], dtype=self.dtype, device=self.device
+ )
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if not empty_past_kv:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and empty_past_kv:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "attention_mask": attention_mask,
+ "num_logits_to_keep": self.config.num_logits_to_keep,
+ "cache_position": cache_position,
+ }
+ )
+ return model_inputs
+
+
+@add_start_docstrings(
+ """
+ The Zamba2 Model with a sequence classification head on top (linear layer).
+
+ [`Zamba2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ ZAMBA2_START_DOCSTRING,
+)
+class Zamba2ForSequenceClassification(Zamba2PreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.model = Zamba2Model(config)
+ self._tied_weights_keys = self.model._tied_weights_keys
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.model.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.model.embed_tokens = value
+
+ @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+ if self.config.pad_token_id is None:
+ sequence_lengths = -1
+ else:
+ if input_ids is not None:
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
+ sequence_lengths = sequence_lengths.to(logits.device)
+ else:
+ sequence_lengths = -1
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+ loss = None
+ if labels is not None:
+ labels = labels.to(logits.device)
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+__all__ = ["Zamba2ForCausalLM", "Zamba2ForSequenceClassification", "Zamba2Model", "Zamba2PreTrainedModel"]
diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py
new file mode 100644
index 00000000000000..2cf8fcde8cc952
--- /dev/null
+++ b/src/transformers/models/zamba2/modular_zamba2.py
@@ -0,0 +1,1874 @@
+# coding=utf-8
+# Copyright 2024 Zyphra Technologies and the HuggingFace Inc. team. All rights reserved.
+#
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from itertools import cycle
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...configuration_utils import PretrainedConfig
+from ...modeling_flash_attention_utils import _flash_attention_forward
+from ...modeling_outputs import BaseModelOutputWithPast
+from ...modeling_utils import PreTrainedModel
+from ...utils import (
+ add_start_docstrings,
+ add_start_docstrings_to_model_forward,
+ is_flash_attn_greater_or_equal_2_10,
+ logging,
+)
+from ...utils.import_utils import (
+ is_causal_conv1d_available,
+ is_mamba_ssm_available,
+)
+from ..gemma.modeling_gemma import GemmaRotaryEmbedding
+from ..llama.modeling_llama import apply_rotary_pos_emb
+from ..mamba2.modeling_mamba2 import MambaRMSNormGated, pad_tensor_by_size, reshape_into_chunks, segment_sum
+from ..zamba.modeling_zamba import (
+ ZambaAttention,
+ ZambaAttentionDecoderLayer,
+ ZambaForCausalLM,
+ ZambaForSequenceClassification,
+ ZambaHybridDynamicCache,
+ ZambaHybridLayer,
+ ZambaMambaDecoderLayer,
+ ZambaMLP,
+ ZambaModel,
+ ZambaRMSNorm,
+ repeat_kv,
+)
+
+
+if is_mamba_ssm_available():
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
+else:
+ selective_state_update, mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined = None, None, None
+
+if is_causal_conv1d_available():
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+else:
+ causal_conv1d_update, causal_conv1d_fn = None, None
+
+is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
+
+
+_CONFIG_FOR_DOC = "Zyphra/Zamba2-2.7B"
+
+logger = logging.get_logger(__name__)
+
+
+class Zamba2Config(PretrainedConfig):
+ r"""
+ This is the configuration class to store the configuration of a [`Zamba2Model`]. It is used to instantiate a
+ Zamba2 model according to the specified arguments, defining the model architecture. Instantiating a configuration
+ with the defaults will yield a similar configuration to that of the Zamba2 model.
+
+ [Zyphra/Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B)
+
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+ documentation from [`PretrainedConfig`] for more information.
+ Args:
+ vocab_size (`int`, *optional*, defaults to 32000):
+ Vocabulary size of the Zamba2 model. Defines the number of different tokens that can be represented by the
+ `inputs_ids` passed when calling [`Zamba2Model`]
+ max_position_embeddings (`int`, *optional*, defaults to 4096):
+ The maximum sequence length that this model might ever be used with.
+ hidden_size (`int`, *optional*, defaults to 2560):
+ Dimension of the hidden representations.
+ num_hidden_layers (`int`, *optional*, defaults to 54):
+ Number of hidden layers in the model.
+ layers_block_type (`list`, *optional*):
+ List of layer types, which can be either "mamba" or "hybrid".
+ mamba_d_state (`int`, *optional*, defaults to 64): shape of the state space latents.
+ mamba_d_conv (`int`, *optional*, defaults to 4): Size of the convolution kernel.
+ mamba_expand (`int`, *optional*, defaults to 2): Expanding factor used to determine the intermediate size.
+ mamba_ngroups (`int`, *optional*, defaults to 1):
+ Number of groups for the evolution matrices of mamba 2.
+ time_step_min (`float`, *optional*, defaults to 0.001):
+ Minimum `time_step` used to bound `dt_proj.bias`.
+ time_step_max (`float`, *optional*, defaults to 0.1):
+ Maximum `time_step` used to bound `dt_proj.bias`.
+ time_step_floor (`float`, *optional*, defaults to 0.0001):
+ Minimum clamping value of the `dt_proj.bias` layer initialization.
+ time_step_limit (`tuple`, *optional*):
+ Accepted range of time step values.
+ n_mamba_heads (`int`, *optional*, defaults to 8):
+ Number of heads for the evolution matrices of mamba 2.
+ use_conv_bias (`bool`, *optional*, defaults to `True`):
+ Whether or not to use bias in the convolution layer of the mixer block.
+ chunk_size (`int`, *optional*, defaults to 256):
+ Size of the chunks that will comprise the sequence.
+ add_bias_linear (`bool`, *optional*, defaults to `False`):
+ Flag indicating whether or not to use bias in various layers
+ intermediate_size (`int`, *optional*, defaults to 4 * hidden_size):
+ Dimension of the MLP representations.
+ hidden_act (`str`, *optional*, defaults to `"gelu"`):
+ The non-linear activation function (function or string) in the MLP.
+ num_attention_heads (`int`, *optional*, defaults to 32):
+ Number of attention heads for each attention layer in the Transformer decoder.
+ num_key_value_heads (`int`, *optional*):
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+ `num_key_value_heads=None`, the model will use Multi Head Attention (MHA), if
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+ by meanpooling all the original heads within that group. For more details checkout [this
+ paper](https://arxiv.org/pdf/2305.13245.pdf).
+ attention_dropout (`float`, *optional*, defaults to 0.0):
+ The dropout ratio for the attention probabilities.
+ num_mem_blocks (`int`, *optional*, defaults to 1):
+ Number of unshared transformer blocks.
+ use_shared_mlp_adapter (`bool`, *optional*, defaults to `False`):
+ If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the shared MLP's.
+ use_shared_attention_adapter (`bool`, *optional*, defaults to `False`):
+ If True, unshared adapters (formally the same as LoRA but used in the base model) will be added to the q, k, v projectors in the shared attention layers.
+ adapter_rank (`int`, *optional*, defaults to 128):
+ Rank of the adapter in the shared MLP and shared attention layers.
+ use_mem_rope (`bool`, *optional*, defaults to `False`):
+ If True, includes RoPE in the shared attention layers.
+ rope_theta (`float`, *optional*, defaults to `10000.0`):
+ The base period of the RoPE embeddings.
+ initializer_range (`float`, *optional*, defaults to 0.02):
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
+ The epsilon used by the rms normalization layers.
+ use_cache (`bool`, *optional*, defaults to `True`):
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
+ relevant if `config.is_decoder=True`.
+ num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
+ Number of prompt logits to calculate during generation. If `None`, all logits will be calculated. If an
+ integer value, only last `num_logits_to_keep` logits will be calculated. Default is 1 because only the
+ logits of the last prompt token are needed for generation. For long sequences, the logits for the entire
+ sequence may use a lot of memory so, setting `num_logits_to_keep=1` will reduce memory footprint
+ significantly.
+ pad_token_id (`int`, *optional*, defaults to 0):
+ The id of the padding token.
+ bos_token_id (`int`, *optional*, defaults to 1):
+ The id of the "beginning-of-sequence" token.
+ eos_token_id (`int`, *optional*, defaults to 2):
+ The id of the "end-of-sequence" token.
+ use_long_context (`bool`, *optional*, defaults to `False`):
+ Activates the context-extended version of Zamba by modifying RoPE.
+ ```python
+ >>> from transformers import Zamba2Model, Zamba2Config
+ >>> # Initializing a Zamba2-2.7B style configuration
+ >>> configuration = Zamba2Config()
+ >>> # Initializing a model from the Zamba2-2.7B style configuration
+ >>> model = Zamba2Model(configuration)
+ >>> # Accessing the model configuration
+ >>> configuration = model.config
+ """
+
+ model_type = "zamba2"
+ keys_to_ignore_at_inference = ["past_key_values"]
+
+ def __init__(
+ self,
+ vocab_size=32000,
+ max_position_embeddings=4096,
+ hidden_size=2560,
+ num_hidden_layers=54,
+ layers_block_type=None,
+ mamba_d_state=64,
+ mamba_d_conv=4,
+ mamba_expand=2,
+ mamba_ngroups=1,
+ time_step_min=0.001,
+ time_step_max=0.1,
+ time_step_floor=1e-4,
+ time_step_limit=None,
+ n_mamba_heads=8,
+ use_conv_bias=True,
+ chunk_size=256,
+ add_bias_linear=False,
+ intermediate_size=None,
+ hidden_act="gelu",
+ num_attention_heads=32,
+ num_key_value_heads=None,
+ attention_dropout=0.0,
+ num_mem_blocks=1,
+ use_shared_mlp_adapter=False,
+ use_shared_attention_adapter=False,
+ adapter_rank=128,
+ use_mem_rope=False,
+ rope_theta=10000,
+ initializer_range=0.02,
+ rms_norm_eps=1e-5,
+ use_cache=True,
+ num_logits_to_keep=1,
+ pad_token_id=0,
+ bos_token_id=1,
+ eos_token_id=2,
+ use_long_context=False,
+ **kwargs,
+ ):
+ super().__init__(
+ pad_token_id=pad_token_id,
+ bos_token_id=bos_token_id,
+ eos_token_id=eos_token_id,
+ **kwargs,
+ )
+ self.vocab_size = vocab_size
+ self.max_position_embeddings = max_position_embeddings
+ self.hidden_size = hidden_size
+ if intermediate_size is None:
+ self.intermediate_size = 4 * hidden_size
+ else:
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.num_mem_blocks = num_mem_blocks
+ self.attention_hidden_size = 2 * hidden_size
+ self.attention_head_dim = 2 * self.hidden_size // self.num_attention_heads
+ self.attention_dropout = attention_dropout
+ self.use_mem_rope = use_mem_rope
+ self.use_long_context = use_long_context
+ if use_mem_rope and use_long_context:
+ a = 8
+ rope_theta = rope_theta * a ** (self.attention_head_dim / (self.attention_head_dim - 2))
+ self.rope_theta = rope_theta
+ self.mamba_d_state = mamba_d_state
+ self.mamba_d_conv = mamba_d_conv
+ self.mamba_expand = mamba_expand
+ self.add_bias_linear = add_bias_linear
+ self.mamba_ngroups = mamba_ngroups
+ self.n_mamba_heads = n_mamba_heads
+ self.mamba_headdim = int(mamba_expand * hidden_size) // n_mamba_heads
+ self.use_conv_bias = use_conv_bias
+ self.chunk_size = chunk_size
+ self.time_step_limit = time_step_limit
+ self.use_shared_mlp_adapter = use_shared_mlp_adapter
+ self.use_shared_attention_adapter = use_shared_attention_adapter
+ self.adapter_rank = adapter_rank
+ self.time_step_min = time_step_min
+ self.time_step_max = time_step_max
+ self.time_step_floor = time_step_floor
+ if use_long_context:
+ self.max_position_embeddings = 16384
+ if num_key_value_heads is None:
+ num_key_value_heads = num_attention_heads
+ self.num_key_value_heads = num_key_value_heads
+ self.num_attention_heads = num_attention_heads
+ self.kv_channels = self.hidden_size // self.num_attention_heads
+ self.num_query_groups = self.num_attention_heads
+ # Below, "mamba" stands for mamba layer, "hybrid" stands for hybrid layer (composed by a shared transformer followed by mamba layer)
+ if layers_block_type is None:
+ self.layers_block_type = (
+ ["mamba"]
+ + (["mamba"] * 5 + ["hybrid"]) * 7
+ + ["mamba"] * 4
+ + ["hybrid"]
+ + ["mamba"] * 3
+ + ["hybrid"]
+ + ["mamba"] * 2
+ )
+ else:
+ self.layers_block_type = layers_block_type
+ self.initializer_range = initializer_range
+ self.rms_norm_eps = rms_norm_eps
+ self.use_cache = use_cache
+ self.num_logits_to_keep = num_logits_to_keep
+
+
+class Zamba2RMSNormGated(MambaRMSNormGated):
+ pass
+
+
+class Zamba2RMSNorm(ZambaRMSNorm):
+ pass
+
+
+def count_mem_blocks_in_config(config: Zamba2Config):
+ """
+ Count number of shared blocks
+ """
+ num_gs = 0
+ for val in config.layers_block_type:
+ if val == "hybrid":
+ num_gs += 1
+ return num_gs
+
+
+def layer_type_list(config: Zamba2Config):
+ """
+ Returns list of layer ids containing hybrid layers
+ """
+ output_list = []
+ for index, type in enumerate(config.layers_block_type):
+ if type == "hybrid":
+ output_list.append(index)
+ return output_list
+
+
+class Zamba2HybridDynamicCache(ZambaHybridDynamicCache):
+ """
+ A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
+ (which has a constant shape regardless of seq_len).
+
+ This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
+ and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
+ For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
+ while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
+ For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
+ while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
+ and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
+ """
+
+ def __init__(
+ self, config: Zamba2Config, batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[str] = None
+ ):
+ self.dtype = dtype
+ self.layers_block_type = config.layers_block_type
+ self.has_previous_state = False
+ self.intermediate_size = int(config.mamba_expand * config.hidden_size)
+ self.ssm_state_size = config.mamba_d_state
+ self.conv_kernel_size = config.mamba_d_conv
+ self.n_mamba_heads = config.n_mamba_heads
+ self.transformer_layers = []
+ self._modules = {}
+ self._parameters = {}
+ self._buffers = {}
+ self.conv_states = {
+ i: torch.zeros(
+ batch_size,
+ self.intermediate_size + 2 * config.mamba_ngroups * config.mamba_d_state,
+ self.conv_kernel_size,
+ device=device,
+ dtype=dtype,
+ )
+ for i in range(config.num_hidden_layers)
+ }
+ self.ssm_states = {
+ i: torch.zeros(
+ batch_size, self.n_mamba_heads, config.mamba_headdim, self.ssm_state_size, device=device, dtype=dtype
+ )
+ for i in range(config.num_hidden_layers)
+ }
+ for i in range(config.num_hidden_layers):
+ if self.layers_block_type[i] == "hybrid":
+ self.transformer_layers.append(i)
+ self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
+ self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
+
+ def update_conv_state(
+ self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
+ ) -> torch.Tensor:
+ conv_state = self.conv_states[layer_idx]
+ cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
+
+ conv_state = conv_state.roll(shifts=-1, dims=-1)
+ conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
+ self.conv_states[layer_idx].zero_()
+ self.conv_states[layer_idx] += conv_state
+ return self.conv_states[layer_idx]
+
+ def reset(self):
+ self.conv_states.zero_()
+ self.ssm_states.zero_()
+
+
+class Zamba2RotaryEmbedding(GemmaRotaryEmbedding):
+ def __init__(
+ self,
+ config: Zamba2Config,
+ device=None,
+ ):
+ super().__init__(config, device)
+ self.rope_kwargs = {"base": config.rope_theta, "dim": config.attention_head_dim}
+ inv_freq, self.attention_scaling = self.rope_init_fn(config=None, device=device, **self.rope_kwargs)
+
+
+class Zamba2Attention(ZambaAttention):
+ """
+ Multi-headed attention from 'Attention Is All You Need' paper.
+
+ Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
+ The input dimension here is attention_hidden_size = 2 * hidden_size, and head_dim = attention_hidden_size // num_heads.
+ The extra factor of 2 comes from the input being the concatenation of original_hidden_states with the output of the previous (mamba) layer
+ (see fig. 2 in https://arxiv.org/pdf/2405.16712).
+ Additionally, replaced
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) with
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim/2)
+ Finally, this attention layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this
+ layer is tied, un-tied adapters (formally the same as LoRA but used in the base model) modules are added to the q, k, v projectors to increase
+ expressivity with a small memory overhead (see Fig. 2 of https://arxiv.org/pdf/2411.15242).
+ """
+
+ def __init__(
+ self,
+ config: Zamba2Config,
+ layer_idx: Optional[int] = None,
+ num_fwd_mem_blocks: int = None,
+ block_id: int = None,
+ ):
+ super().__init__(config, layer_idx)
+ self.num_fwd_mem_blocks = num_fwd_mem_blocks
+ self.layer_block_map = layer_type_list(config)
+ self.block_id = block_id
+ self.is_causal = True
+
+ if config.use_shared_attention_adapter:
+ self.linear_q_adapter_list = nn.ModuleList([])
+ self.linear_k_adapter_list = nn.ModuleList([])
+ self.linear_v_adapter_list = nn.ModuleList([])
+
+ for i in range(self.num_fwd_mem_blocks):
+ if i % config.num_mem_blocks == block_id:
+ linear_q_adapter = nn.Sequential(
+ nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
+ nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
+ )
+ linear_k_adapter = nn.Sequential(
+ nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
+ nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
+ )
+ linear_v_adapter = nn.Sequential(
+ nn.Linear(self.attention_hidden_size, self.config.adapter_rank, bias=False),
+ nn.Linear(self.config.adapter_rank, self.attention_hidden_size, bias=False),
+ )
+ else:
+ linear_q_adapter = nn.Identity()
+ linear_k_adapter = nn.Identity()
+ linear_v_adapter = nn.Identity()
+ self.linear_q_adapter_list.append(linear_q_adapter)
+ self.linear_k_adapter_list.append(linear_k_adapter)
+ self.linear_v_adapter_list.append(linear_v_adapter)
+
+ self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)}
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ layer_idx: int,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Zamba2HybridDynamicCache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+ if self.config.use_shared_attention_adapter:
+ adapter_layer_idx = self.layer_dic[layer_idx]
+ query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states)
+ key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states)
+ value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if self.config.use_mem_rope:
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim / 2)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+ attn_weights = attn_weights + causal_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
+# Added softmax_scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of self._flash_attention_forward
+# dropped use_sliding_windows from the arguments of self._flash_attention_forward
+class Zamba2FlashAttention2(Zamba2Attention):
+ """
+ Zamba2 flash attention module. This module inherits from `Zamba2Attention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ layer_idx: int,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Zamba2HybridDynamicCache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+ if self.config.use_shared_attention_adapter:
+ adapter_layer_idx = self.layer_dic[layer_idx]
+ query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states)
+ key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states)
+ value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x head_dim x hidden_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if self.config.use_mem_rope:
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+ dropout_rate = 0.0 if not self.training else self.attention_dropout
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reashape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+ softmax_scale = 1 / math.sqrt(self.head_dim / 2)
+
+ attn_output = _flash_attention_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ dropout=dropout_rate,
+ softmax_scale=softmax_scale,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.attention_hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+# Adapted from transformers.models.mistral.modeling_mistral.MistralAttention:
+# added scale = 1 / (query_states.shape[-1]/2)**0.5 to the arguments of torch.nn.functional.scaled_dot_product_attention
+class Zamba2SdpaAttention(Zamba2Attention):
+ """
+ Zamba2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+ `Zamba2Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+ SDPA API.
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ layer_idx: int,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Zamba2HybridDynamicCache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if output_attentions:
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+ logger.warning_once(
+ "Zamba2Model is using Zamba2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+ )
+ return super().forward(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+ if self.config.use_shared_attention_adapter:
+ adapter_layer_idx = self.layer_dic[layer_idx]
+ query_states = query_states + self.linear_q_adapter_list[adapter_layer_idx](hidden_states)
+ key_states = key_states + self.linear_k_adapter_list[adapter_layer_idx](hidden_states)
+ value_states = value_states + self.linear_v_adapter_list[adapter_layer_idx](hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if self.config.use_mem_rope:
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ key_states, value_states = past_key_value.update(key_states, value_states, layer_idx)
+
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ causal_mask = attention_mask
+ if attention_mask is not None:
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
+ if query_states.device.type == "cuda" and attention_mask is not None:
+ query_states = query_states.contiguous()
+ key_states = key_states.contiguous()
+ value_states = value_states.contiguous()
+
+ softmax_scale = 1 / math.sqrt(self.head_dim / 2)
+
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attn_mask=causal_mask,
+ dropout_p=self.attention_dropout if self.training else 0.0,
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
+ scale=softmax_scale,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.view(bsz, q_len, self.attention_hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
+ZAMBA2_ATTENTION_CLASSES = {
+ "eager": Zamba2Attention,
+ "flash_attention_2": Zamba2FlashAttention2,
+ "sdpa": Zamba2SdpaAttention,
+}
+
+
+class Zamba2MambaMixer(nn.Module):
+ """
+ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
+ ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
+ and is why Mamba is called **selective** state spaces)
+ """
+
+ def __init__(self, config: Zamba2Config, layer_idx: int = None):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.ssm_state_size = config.mamba_d_state
+ self.conv_kernel_size = config.mamba_d_conv
+ self.intermediate_size = int(config.mamba_expand * self.hidden_size)
+ self.layer_idx = layer_idx
+ self.use_conv_bias = config.use_conv_bias
+ self.activation = "silu"
+ self.act = nn.SiLU()
+
+ self.n_groups = config.mamba_ngroups
+ self.head_dim = config.mamba_headdim
+ self.num_heads = self.config.n_mamba_heads
+ self.chunk_size = config.chunk_size
+
+ self.time_step_limit = config.time_step_limit
+ self.time_step_min = config.time_step_min
+ self.time_step_max = config.time_step_max
+
+ self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
+ self.conv1d = nn.Conv1d(
+ in_channels=self.conv_dim,
+ out_channels=self.conv_dim,
+ bias=True,
+ kernel_size=config.mamba_d_conv,
+ groups=self.conv_dim,
+ padding=config.mamba_d_conv - 1,
+ )
+
+ # projection of the input hidden states
+ projection_size = self.intermediate_size + self.conv_dim + self.num_heads
+ self.in_proj = nn.Linear(
+ self.hidden_size,
+ projection_size,
+ bias=config.add_bias_linear,
+ )
+ # selective projection used to make dt, B and C input dependant
+
+ # time step projection (discretization)
+ # instantiate once and copy inv_dt in init_weights of PretrainedModel
+ self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
+
+ # S4D real initialization. These are not discretized!
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
+ A = torch.arange(1, self.num_heads + 1)
+ self.A_log = nn.Parameter(torch.log(A))
+ self.A_log._no_weight_decay = True
+ self.norm = Zamba2RMSNormGated(self.intermediate_size, eps=1e-5)
+ self.D = nn.Parameter(torch.ones(self.num_heads))
+ self.D._no_weight_decay = True
+
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear)
+
+ if not is_fast_path_available:
+ logger.warning_once(
+ "The fast path is not available because on of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
+ " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
+ " https://github.com/Dao-AILab/causal-conv1d"
+ )
+
+ def cuda_kernels_forward(
+ self,
+ hidden_states: torch.Tensor,
+ cache_params: Optional[Zamba2HybridDynamicCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ # set up dimensions for reshapes later
+
+ batch_size, seq_len, _ = hidden_states.shape
+ groups_time_state_size = self.n_groups * self.ssm_state_size
+ d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads
+
+ # getting projected states from cache if it exists
+ if cache_params is not None and cache_params.has_previous_state:
+ in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
+ d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
+ split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads]
+ _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1)
+
+ hidden_states_B_C = causal_conv1d_update(
+ hidden_states_B_C,
+ cache_params.conv_states[self.layer_idx],
+ self.conv1d.weight.squeeze(1),
+ self.conv1d.bias,
+ self.activation,
+ )
+
+ hidden_states, B, C = torch.split(
+ hidden_states_B_C,
+ [self.intermediate_size, groups_time_state_size, groups_time_state_size],
+ dim=-1,
+ )
+ A = -torch.exp(self.A_log.float()) # (nheads,)
+
+ A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
+ dt = dt[:, :, None].expand(-1, -1, self.head_dim)
+ dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
+ D = self.D[:, None, ...].expand(-1, self.head_dim)
+ B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
+ C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
+ hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
+ hidden_states = selective_state_update(
+ cache_params.ssm_states[self.layer_idx],
+ hidden_states_reshaped,
+ dt,
+ A,
+ B,
+ C,
+ D,
+ z=None,
+ dt_bias=dt_bias,
+ dt_softplus=True,
+ )
+ hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
+ hidden_states = self.norm(hidden_states, gate)
+ out = self.out_proj(hidden_states)[:, None, ...]
+ # if no cache is found, calling the kernel
+ else:
+ if attention_mask is not None and not torch.all(attention_mask == 1):
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
+ dtype = hidden_states.dtype
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+ # 1. Gated MLP's linear projection
+ projected_states = self.in_proj(hidden_states)
+ A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
+ dt_limit_kwargs = {} if self.time_step_limit is None else {"dt_limit": self.time_step_limit}
+ if attention_mask is not None:
+ input_not_masked = torch.all(attention_mask == 1)
+ else:
+ input_not_masked = True
+
+ if self.training and cache_params is None and input_not_masked:
+ out, ssm_state = mamba_split_conv1d_scan_combined(
+ projected_states,
+ self.conv1d.weight.squeeze(1),
+ self.conv1d.bias,
+ self.dt_bias,
+ A,
+ D=self.D,
+ chunk_size=self.chunk_size,
+ seq_idx=None,
+ activation=self.activation,
+ rmsnorm_weight=self.norm.weight,
+ rmsnorm_eps=self.norm.variance_epsilon,
+ outproj_weight=self.out_proj.weight,
+ outproj_bias=self.out_proj.bias,
+ headdim=self.head_dim,
+ ngroups=self.n_groups,
+ norm_before_gate=False,
+ return_final_states=True,
+ **dt_limit_kwargs,
+ )
+
+ else:
+ gate, hidden_states_B_C, time_step = torch.split(
+ projected_states,
+ [self.intermediate_size, self.conv_dim, self.num_heads],
+ dim=-1,
+ )
+
+ # 1D Convolution
+ if cache_params is not None:
+ hidden_states_B_C_t = hidden_states_B_C.transpose(1, 2)
+ conv_state = nn.functional.pad(
+ hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0)
+ )
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
+ hidden_states_B_C = self.act(
+ self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len]
+ ) # (B, L, self.d_inner + 2 * ngroups * d_state)
+ else:
+ hidden_states_B_C = causal_conv1d_fn(
+ x=hidden_states_B_C.transpose(1, 2),
+ weight=self.conv1d.weight.squeeze(1),
+ bias=self.conv1d.bias,
+ activation=self.activation,
+ ).transpose(1, 2)[:, :seq_len]
+ hidden_states, B, C = torch.split(
+ hidden_states_B_C,
+ [self.intermediate_size, groups_time_state_size, groups_time_state_size],
+ dim=-1,
+ )
+ if attention_mask is not None and not torch.all(attention_mask == 1):
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
+ dtype = hidden_states.dtype
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+ scan_output, ssm_state = mamba_chunk_scan_combined(
+ hidden_states.view(batch_size, seq_len, -1, self.head_dim),
+ time_step,
+ A,
+ B.view(batch_size, seq_len, self.n_groups, -1),
+ C.view(batch_size, seq_len, self.n_groups, -1),
+ chunk_size=self.chunk_size,
+ D=self.D,
+ z=None,
+ seq_idx=None,
+ return_final_states=True,
+ dt_bias=self.dt_bias,
+ dt_softplus=True,
+ **dt_limit_kwargs,
+ )
+ if ssm_state is not None and cache_params is not None:
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+ scan_output = scan_output.view(batch_size, seq_len, -1)
+ # Multiply "gate" branch and apply extra normalization layer
+ scan_output = self.norm(scan_output, gate)
+ out = self.out_proj(scan_output)
+ return out
+
+ # fmt: off
+ def torch_forward(self, input_states, cache_params: Optional[Zamba2HybridDynamicCache]=None, cache_position:Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None):
+ batch_size, seq_len, _ = input_states.shape
+ dtype = input_states.dtype
+ # Gated MLP's linear projection
+ if cache_params is not None and cache_params.has_previous_state:
+ projected_states = self.in_proj(input_states.squeeze(1))
+ else:
+ if attention_mask is not None and not torch.all(attention_mask==1):
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
+ input_states = (input_states * attention_mask[:, :, None]).to(dtype)
+ projected_states = self.in_proj(input_states)
+ d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size- self.num_heads) // 2
+ _, _, gate, hidden_states, dt = projected_states.split(
+ [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
+ )
+
+ # Convolution sequence transformation
+ if cache_params is not None:
+ ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+ ssm_state = ssm_state.to(hidden_states.device)
+ if cache_params.has_previous_state:
+ gate = gate.unsqueeze(1)
+ conv_state = cache_params.conv_states[self.layer_idx] # [batch, intermediate_size, conv_kernel_size]
+ conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
+ # handle batched generation - states are copied through
+ conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
+ hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1)
+ if self.use_conv_bias:
+ hidden_states += self.conv1d.bias
+ hidden_states = self.act(hidden_states).to(dtype)[:, None, ...] # [batch, 1, intermediate_size] : decoding
+ else:
+ hidden_states = hidden_states.transpose(1,2)
+ conv_state = nn.functional.pad(
+ hidden_states,
+ (self.conv_kernel_size - hidden_states.shape[-1], 0)
+ )
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
+ hidden_states = self.act(self.conv1d(hidden_states).transpose(1,2))[:, :seq_len, :] # [batch, intermediate_size, seq_len]
+ if attention_mask is not None and not torch.all(attention_mask==1):
+ dtype = hidden_states.dtype
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
+ else:
+ ssm_state = torch.zeros(
+ (batch_size, self.num_heads, self.head_dim, self.ssm_state_size),
+ device=hidden_states.device, dtype=dtype
+ )
+ hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2))
+ hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1)
+ A = -torch.exp(self.A_log.float()) # [num_heads]
+ if cache_params is not None and cache_params.has_previous_state:
+ # Note: there is no need to pad parameter matrices here, as there is just one new token
+ # for batched generation
+ dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
+ dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
+ # [num_heads] -> [num_heads, head_dim]
+ dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)
+
+ dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
+ dt = torch.clamp(dt, self.time_step_min) #, self.time_step_max)
+ A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
+ # [bsz, num_heads, head_dim, state_size]
+ dA = torch.exp(dt[..., None] * A)
+
+ # Discretize B
+ # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
+ # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
+ B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
+ B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
+ B = B.reshape(batch_size, -1, B.shape[-1])
+ # [bsz, num_heads, head_dim, state_size]
+ dB = dt[..., None] * B[..., None, :]
+
+ # Discretize x into dB
+ # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
+ hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
+ dBx = dB * hidden_states[..., None]
+
+ # State calculation
+ cache_params.ssm_states[self.layer_idx].copy_(
+ cache_params.ssm_states[self.layer_idx] * dA + dBx
+ )
+
+ # Subsequent output
+ # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
+ C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
+ C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
+ C = C.reshape(batch_size, -1, C.shape[-1])
+ # [bsz, num_heads, head_dim]
+
+ ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype) # Shape: [b, h, d, n]
+ # Reshape ssm_states to merge the first two dimensions
+ ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size) # Shape: [b*h, d, n]
+ C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1) # Shape: [b*h, n, 1]
+ y = torch.bmm(ssm_states_reshaped, C_reshaped)
+ y = y.view(batch_size, self.num_heads, self.head_dim)
+
+ # D skip connection
+ # [num_heads] -> [num_heads, head_dim]
+ D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
+ y = (y + hidden_states * D).to(y.dtype)
+
+ # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
+ y = y.reshape(batch_size, -1)[:, None, ...]
+ else:
+ # begin ssd naive implementation without einsums
+ dt = nn.functional.softplus(dt + self.dt_bias)
+ dt = torch.clamp(dt, self.time_step_min)
+ hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
+ B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
+ C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
+ B = B.repeat(1, 1, self.num_heads // self.n_groups, 1)
+ C = C.repeat(1, 1, self.num_heads // self.n_groups, 1)
+ pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
+
+ D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)
+
+ # Discretize x and A
+ hidden_states = hidden_states * dt[..., None]
+ A = A.to(hidden_states.dtype) * dt
+
+ # Rearrange into blocks/chunks
+ hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]
+
+
+ # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
+ A = A.permute(0, 3, 1, 2)
+ A_cumsum = torch.cumsum(A, dim=-1)
+
+ # 1. Compute the output for each intra-chunk (diagonal blocks)
+ # This is the analog of a causal mask
+ L = torch.exp(segment_sum(A))
+
+ # First, contraction of C and B to get G (attention-weights like)
+ G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, : ,:] # shape: (b, c, l, s, h, n)
+ G = G_intermediate.sum(dim=-1) # shape: (b, c, l, s, h)
+
+
+ # Step 2: Compute M, equivalent to applying attention mask to weights
+ M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
+ M = M_intermediate.sum(dim=-1)
+
+ # Step 3: Compute Y_diag (apply to values)
+ Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)
+
+ # (right term of low-rank factorization of off-diagonal blocks; B terms)
+
+ decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum))
+ B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
+ # permute back B * decay states
+ states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
+ if cache_params is not None and cache_params.has_previous_state:
+ previous_states = cache_params.ssm_states[self.layer_idx][:, None, ...]
+ else:
+ previous_states = torch.zeros_like(states[:, :1])
+ states = torch.cat([previous_states, states], dim=1)
+ decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))
+
+ states_permuted = states.permute(0, 2, 1, 3, 4)
+ result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
+ new_states = result.permute(0, 2, 1, 3, 4)
+ states, ssm_state = new_states[:, :-1], new_states[:, -1]
+
+ # Compute state -> output conversion per chunk
+ # (left term of low-rank factorization of off-diagonal blocks; C terms)
+ state_decay_out = torch.exp(A_cumsum)
+ # compute Yoff
+ C_times_states = (C[..., None, :] * states[:, :, None, ...])
+ state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
+ Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])
+ # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
+
+ y = Y_diag + Y_off
+ # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
+ y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)
+
+ y = y + D_residual
+ # Cutting off padded chunks
+ if pad_size > 0:
+ y = y[:, :seq_len, :, :]
+ y = y.reshape(batch_size, seq_len, -1)
+ if ssm_state is not None and cache_params is not None:
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+
+ scan_output = self.norm(y, gate)
+
+ # end ssd naive
+
+ # 4. Final linear projection
+ contextualized_states = self.out_proj(scan_output.to(dtype)) # [batch, seq_len, hidden_size]
+ return contextualized_states
+ # fmt: on
+
+ def forward(
+ self,
+ hidden_states,
+ cache_params: Optional[Zamba2HybridDynamicCache] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ ):
+ if is_fast_path_available and "cuda" in self.in_proj.weight.device.type:
+ return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
+
+ return self.torch_forward(hidden_states, cache_params, cache_position, attention_mask)
+
+
+class Zamba2MLP(ZambaMLP):
+ def __init__(self, config: Zamba2Config, num_fwd_mem_blocks=None, block_id: int = None):
+ """
+ This MLP layer contributes to tied transformer blocks aimed to increasing compute without increasing model size. Because this layer
+ is tied, un-tied adapter modules (formally same as LoRA, but used in the base model) are added to the up and gate projectors to increase expressivity with a small memory overhead.
+ """
+ super().__init__(config)
+ self.config = config
+ self.num_fwd_mem_blocks = num_fwd_mem_blocks
+ self.block_id = block_id
+
+ def gated_act_fn(x):
+ x = torch.chunk(x, 2, dim=-1)
+ return self.act_fn(x[0]) * x[1]
+
+ self.gated_act_fn = gated_act_fn
+
+ del self.gate_proj
+ del self.up_proj
+ del self.down_proj
+ self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=config.add_bias_linear)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.add_bias_linear)
+
+ if self.config.use_shared_mlp_adapter:
+ self.gate_up_proj_adapter_list = nn.ModuleList([])
+ for i in range(self.num_fwd_mem_blocks):
+ if i % config.num_mem_blocks == block_id:
+ gate_up_proj_adapter = nn.Sequential(
+ nn.Linear(self.config.hidden_size, self.config.adapter_rank, bias=False),
+ nn.Linear(self.config.adapter_rank, 2 * self.intermediate_size, bias=False),
+ )
+ else:
+ gate_up_proj_adapter = nn.Identity()
+ self.gate_up_proj_adapter_list.append(gate_up_proj_adapter)
+
+ layer_block_map = layer_type_list(config)
+ self.layer_dic = {value: index for index, value in enumerate(layer_block_map)}
+
+ def forward(self, hidden_state, layer_idx=None):
+ gate_up_state = self.gate_up_proj(hidden_state)
+ if self.config.use_shared_mlp_adapter:
+ layer_idx = self.layer_dic[layer_idx]
+ gate_up_state = gate_up_state + self.gate_up_proj_adapter_list[layer_idx](hidden_state)
+
+ hidden_state = self.gated_act_fn(gate_up_state)
+ output = self.down_proj(hidden_state)
+ return output
+
+
+class Zamba2AttentionDecoderLayer(ZambaAttentionDecoderLayer):
+ def __init__(self, config: Zamba2Config, block_id: int = None, layer_idx: Optional[int] = None):
+ self.block_id = block_id
+ num_gs = count_mem_blocks_in_config(config)
+ super().__init__(config, layer_idx)
+ self.self_attn = ZAMBA2_ATTENTION_CLASSES[config._attn_implementation](
+ config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id
+ )
+ self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ original_hidden_states: torch.Tensor,
+ layer_idx: int,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Zamba2HybridDynamicCache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): output of previous Mamba layer of shape `(batch, seq_len, embed_dim)`
+ original_hidden_states (`torch.FloatTensor`): word embedding output of shape `(batch, seq_len, embed_dim)`.
+ This is concatenated with `hidden_states` (which is the output of the previous (mamba) layer). The
+ concatenated tensor is then used as input of the pre-attention RMSNorm
+ (see fig. 2 in https://arxiv.org/pdf/2405.16712).
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ """
+ hidden_states = torch.concatenate([hidden_states, original_hidden_states], dim=-1)
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ layer_idx=layer_idx,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ **kwargs,
+ )
+
+ hidden_states = self.pre_ff_layernorm(hidden_states)
+ hidden_states = self.feed_forward(hidden_states, layer_idx)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (present_key_value,)
+
+ return outputs
+
+
+class Zamba2MambaDecoderLayer(ZambaMambaDecoderLayer):
+ def __init__(self, config: Zamba2Config, layer_idx: int):
+ super().__init__(config, layer_idx)
+ self.mamba = Zamba2MambaMixer(config=config, layer_idx=layer_idx)
+ self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ original_hidden_states: Optional[torch.Tensor] = None,
+ layer_idx: int = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Zamba2HybridDynamicCache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[torch.LongTensor] = None,
+ transformer_hidden_states: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ transformer_hidden_states (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Output of the previous shared transformer layer (if present) of shape `(batch_size, seq_len, embed_dim)`.
+ """
+
+ residual = hidden_states
+
+ # `transformer_hidden_states` is the output from shared transformer + linear layer (see fig. 2 in https://arxiv.org/pdf/2405.16712).
+ # `transformer_hidden_states` is then added to the input to the mamba layer below (as described in eq. (6) of https://arxiv.org/pdf/2405.16712).
+ hidden_states = (
+ hidden_states + transformer_hidden_states if transformer_hidden_states is not None else hidden_states
+ )
+ hidden_states = self.input_layernorm(hidden_states)
+
+ hidden_states = self.mamba(
+ hidden_states=hidden_states,
+ cache_params=past_key_value,
+ attention_mask=attention_mask,
+ )
+
+ self_attn_weights = None
+
+ # residual connection after mamba
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (past_key_value,)
+
+ return outputs
+
+
+class Zamba2HybridLayer(ZambaHybridLayer):
+ def __init__(
+ self, shared_transformer: Zamba2AttentionDecoderLayer, linear: nn.Linear, mamba: Zamba2MambaDecoderLayer
+ ):
+ super().__init__(shared_transformer, linear, mamba)
+ del self.shared_transf
+ self.shared_transformer = shared_transformer
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ original_hidden_states: Optional[torch.Tensor] = None,
+ layer_idx: int = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Zamba2HybridDynamicCache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ original_hidden_states (`torch.FloatTensor`): word embedding output that will be concatenated with
+ hidden activations to form the input of the shared transformer layer.
+ layer_idx (`int`): layer number.
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
+ `(batch, sequence_length)` where padding elements are indicated by 0.
+ past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
+ Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
+ with `head_dim` being the embedding dimension of each attention head.
+ """
+
+ layer_outputs = self.shared_transformer(
+ hidden_states,
+ original_hidden_states=original_hidden_states,
+ layer_idx=layer_idx,
+ attention_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ transformer_hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ self_attn_weights = layer_outputs[1]
+
+ transformer_hidden_states = self.linear(transformer_hidden_states)
+
+ layer_outputs = self.mamba_decoder(
+ hidden_states,
+ transformer_hidden_states=transformer_hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ if output_attentions:
+ layer_outputs = (layer_outputs[0], self_attn_weights) + layer_outputs[2:]
+
+ return layer_outputs
+
+
+ZAMBA2_START_DOCSTRING = r"""
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+ etc.)
+
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+ and behavior.
+
+ Parameters:
+ config ([`Zamba2Config`]):
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
+ load the weights associated with the model, only the configuration. Check out the
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+ "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.",
+ ZAMBA2_START_DOCSTRING,
+)
+class Zamba2PreTrainedModel(PreTrainedModel):
+ config_class = Zamba2Config
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["Zamba2AttentionDecoderLayer", "Zamba2MambaDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = False
+ _supports_cache_class = True # Note: only supports Zamba2HybridDynamicCache
+ _is_stateful = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, Zamba2MambaMixer):
+ module.A_log._no_weight_decay = True
+ module.D._no_weight_decay = True
+
+ dt = torch.exp(
+ torch.rand(self.config.n_mamba_heads)
+ * (math.log(self.config.time_step_max) - math.log(self.config.time_step_min))
+ + math.log(self.config.time_step_min)
+ ).clamp(min=self.config.time_step_floor)
+ # # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
+
+ with torch.no_grad():
+ module.dt_bias.copy_(inv_dt)
+ module.dt_bias._no_reinit = True
+
+
+ZAMBA2_INPUTS_DOCSTRING = r"""
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
+ it.
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are input IDs?](../glossary#input-ids)
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+ [`PreTrainedTokenizer.__call__`] for details.
+
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
+ `past_key_values`).
+
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
+ information on the default strategy.
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
+ config.n_positions - 1]`.
+
+ [What are position IDs?](../glossary#position-ids)
+ past_key_values (`Zamba2HybridDynamicCache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+ A Zamba2HybridDynamicCache object containing pre-computed hidden-states (keys and values in the
+ self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
+ `past_key_values` input) to speed up sequential decoding.
+ Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
+ Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
+ `(batch_size, d_inner, d_state)` respectively.
+ See the `Zamba2HybridDynamicCache` class for more details.
+
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+ `input_ids` of shape `(batch_size, sequence_length)`.
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
+ model's internal embedding lookup matrix.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
+ `past_key_values`).
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
+ tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
+ more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
+ the complete sequence length.
+"""
+
+
+@add_start_docstrings(
+ "The bare Zamba2 Model outputting raw hidden-states without any specific head on top.",
+ ZAMBA2_START_DOCSTRING,
+)
+class Zamba2Model(ZambaModel, Zamba2PreTrainedModel):
+ """
+ Model consisting of *config.num_hidden_layers* layers.
+
+ Args:
+ config: Zamba2Config
+ """
+
+ def __init__(self, config: Zamba2Config):
+ Zamba2PreTrainedModel.__init__(self, config)
+ self.config = config
+ self.padding_idx = config.pad_token_id
+ self.vocab_size = config.vocab_size
+
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+ blocks = [Zamba2AttentionDecoderLayer(config, block_id=k) for k in range(config.num_mem_blocks)]
+ mamba_layers = []
+ linear_layers = []
+ self.layers_block_type = config.layers_block_type
+ for i in range(config.num_hidden_layers):
+ if config.layers_block_type[i] == "mamba":
+ mamba_layers.append(Zamba2MambaDecoderLayer(config, layer_idx=i))
+ elif config.layers_block_type[i] == "hybrid":
+ linear_layers.append(nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False))
+ mamba_layers.append(Zamba2MambaDecoderLayer(config, layer_idx=i))
+ mamba_layers = iter(mamba_layers)
+ linear_layers = iter(linear_layers)
+ blocks = cycle(blocks)
+ layers = self.get_layers(blocks, linear_layers, mamba_layers)
+ self.layers = nn.ModuleList(layers)
+
+ self._attn_implementation = config._attn_implementation
+ self.final_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+ if config.use_mem_rope:
+ if config.use_long_context:
+ logger.warning_once(
+ "`use_long_context` set to `True`: using rescaled `rope_theta` and extended `max_position_embeddings`."
+ )
+ self.rotary_emb = Zamba2RotaryEmbedding(config)
+ self.gradient_checkpointing = False
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_layers(self, blocks, linear_layers, mamba_layers):
+ layers = []
+ self._tied_weights_keys = []
+ for layer_id, layer_type in enumerate(self.layers_block_type):
+ if layer_type == "hybrid":
+ block = next(blocks)
+ if self.config.num_mem_blocks * len(layer_type_list(self.config)) > 1:
+ prefix_name = f"layers.{layer_id}."
+ tied_keys = [
+ "shared_transformer.self_attn.q_proj.weight",
+ "shared_transformer.self_attn.k_proj.weight",
+ "shared_transformer.self_attn.v_proj.weight",
+ "shared_transformer.self_attn.o_proj.weight",
+ "shared_transformer.feed_forward.gate_up_proj.weight",
+ "shared_transformer.feed_forward.down_proj.weight",
+ "shared_transformer.input_layernorm.weight",
+ "shared_transformer.pre_ff_layernorm.weight",
+ ]
+ self._tied_weights_keys = [*self._tied_weights_keys, *[prefix_name + key for key in tied_keys]]
+ if self.config.use_shared_mlp_adapter:
+ tied_keys_adapter = []
+ adapter_id = 0
+ for _layer_type in self.layers_block_type:
+ if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id:
+ tied_keys_adapter.append(
+ "shared_transformer.feed_forward.gate_up_proj_adapter_list."
+ + str(adapter_id)
+ + ".0.weight"
+ )
+ tied_keys_adapter.append(
+ "shared_transformer.feed_forward.gate_up_proj_adapter_list."
+ + str(adapter_id)
+ + ".1.weight"
+ )
+ adapter_id += 1
+ self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter]
+ if self.config.use_shared_attention_adapter:
+ tied_keys_adapter = []
+ adapter_id = 0
+ for _layer_type in self.layers_block_type:
+ if _layer_type == "hybrid" and adapter_id % self.config.num_mem_blocks == block.block_id:
+ tied_keys_adapter.append(
+ "shared_transformer.self_attn.linear_q_adapter_list."
+ + str(adapter_id)
+ + ".0.weight"
+ )
+ tied_keys_adapter.append(
+ "shared_transformer.self_attn.linear_k_adapter_list."
+ + str(adapter_id)
+ + ".0.weight"
+ )
+ tied_keys_adapter.append(
+ "shared_transformer.self_attn.linear_v_adapter_list."
+ + str(adapter_id)
+ + ".0.weight"
+ )
+ tied_keys_adapter.append(
+ "shared_transformer.self_attn.linear_q_adapter_list."
+ + str(adapter_id)
+ + ".1.weight"
+ )
+ tied_keys_adapter.append(
+ "shared_transformer.self_attn.linear_k_adapter_list."
+ + str(adapter_id)
+ + ".1.weight"
+ )
+ tied_keys_adapter.append(
+ "shared_transformer.self_attn.linear_v_adapter_list."
+ + str(adapter_id)
+ + ".1.weight"
+ )
+ adapter_id += 1
+ self._tied_weights_keys = [*self._tied_weights_keys, *tied_keys_adapter]
+ layers.append(Zamba2HybridLayer(block, next(linear_layers), next(mamba_layers)))
+ else:
+ layers.append(next(mamba_layers))
+ return layers
+
+ @add_start_docstrings_to_model_forward(ZAMBA2_INPUTS_DOCSTRING)
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Zamba2HybridDynamicCache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+ )
+
+ if self.gradient_checkpointing and self.training and use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
+ )
+ use_cache = False
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ hidden_states = inputs_embeds
+
+ original_hidden_states = torch.clone(inputs_embeds)
+ # original_hidden_states: word embedding output that will be concatenated with hidden activations to form the input of the shared transformer layer
+
+ if use_cache and past_key_values is None:
+ logger.warning_once(
+ "Zamba2 requires an initialized `Zamba2HybridDynamicCache` to return a cache. None was "
+ "provided, so no cache will be returned."
+ )
+
+ if cache_position is None:
+ cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
+ if position_ids is None:
+ position_ids = cache_position.unsqueeze(0)
+
+ causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
+
+ # create position embeddings to be shared across the decoder layers
+ if self.config.use_mem_rope:
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+ else:
+ position_embeddings = None
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+
+ for layer_idx, layer in enumerate(self.layers):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer.__call__,
+ hidden_states,
+ original_hidden_states,
+ layer_idx,
+ attention_mask,
+ causal_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = layer(
+ hidden_states,
+ original_hidden_states=original_hidden_states,
+ layer_idx=layer_idx,
+ attention_mask=attention_mask,
+ causal_mask=causal_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ if layer_outputs[1] is not None:
+ # append attentions only of attention layers. Mamba layers return `None` as the attention weights
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.final_layernorm(hidden_states)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if past_key_values and not past_key_values.has_previous_state:
+ past_key_values.has_previous_state = True
+
+ next_cache = None if not use_cache else past_key_values
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class Zamba2ForCausalLM(ZambaForCausalLM):
+ pass
+
+
+@add_start_docstrings(
+ """
+ The Zamba2 Model with a sequence classification head on top (linear layer).
+
+ [`Zamba2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it requires to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """,
+ ZAMBA2_START_DOCSTRING,
+)
+class Zamba2ForSequenceClassification(ZambaForSequenceClassification):
+ pass
+
+
+__all__ = [
+ "Zamba2Config",
+ "Zamba2ForCausalLM",
+ "Zamba2ForSequenceClassification",
+ "Zamba2Model",
+ "Zamba2PreTrainedModel",
+]
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py
index e3463461ea07e5..31194445337acd 100644
--- a/src/transformers/utils/dummy_pt_objects.py
+++ b/src/transformers/utils/dummy_pt_objects.py
@@ -10345,6 +10345,34 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
+class Zamba2ForCausalLM(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Zamba2ForSequenceClassification(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Zamba2Model(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
+class Zamba2PreTrainedModel(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+
class ZoeDepthForDepthEstimation(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index e85f2663624740..8ab812630dc8eb 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -2323,6 +2323,7 @@ def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1
"mamba",
"xlnet",
"zamba",
+ "zamba2",
)
has_standard_cache = not any(
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
diff --git a/tests/models/zamba/test_modeling_zamba.py b/tests/models/zamba/test_modeling_zamba.py
index a6dd516f98a412..ee47f98a1f4133 100644
--- a/tests/models/zamba/test_modeling_zamba.py
+++ b/tests/models/zamba/test_modeling_zamba.py
@@ -46,7 +46,7 @@
ZambaModel,
)
from transformers.models.zamba.modeling_zamba import (
- HybridMambaAttentionDynamicCache,
+ ZambaHybridDynamicCache,
)
@@ -215,9 +215,7 @@ def create_and_check_decoder_model_past_large_inputs(
# first forward pass
# Attention: Zamba needs the cache to be initialized to return a cache!
- past_key_values = HybridMambaAttentionDynamicCache(
- config, input_ids.shape[0], model.dtype, device=model.device
- )
+ past_key_values = ZambaHybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device)
outputs = model(
input_ids,
attention_mask=input_mask,
diff --git a/tests/models/zamba2/__init__.py b/tests/models/zamba2/__init__.py
new file mode 100644
index 00000000000000..e69de29bb2d1d6
diff --git a/tests/models/zamba2/test_modeling_zamba2.py b/tests/models/zamba2/test_modeling_zamba2.py
new file mode 100644
index 00000000000000..a64bf543237fa4
--- /dev/null
+++ b/tests/models/zamba2/test_modeling_zamba2.py
@@ -0,0 +1,645 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Testing suite for the PyTorch Zamba model."""
+
+import math
+import tempfile
+import unittest
+
+import pytest
+from parameterized import parameterized
+
+from transformers import AutoTokenizer, Zamba2Config, is_torch_available
+from transformers.testing_utils import (
+ require_bitsandbytes,
+ require_flash_attn,
+ require_torch,
+ require_torch_gpu,
+ slow,
+ torch_device,
+)
+
+from ...generation.test_utils import GenerationTesterMixin
+from ...test_configuration_common import ConfigTester
+from ...test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask
+from ...test_pipeline_mixin import PipelineTesterMixin
+
+
+if is_torch_available():
+ import torch
+
+ from transformers import (
+ Zamba2ForCausalLM,
+ Zamba2ForSequenceClassification,
+ Zamba2Model,
+ )
+ from transformers.models.zamba2.modeling_zamba2 import (
+ Zamba2HybridDynamicCache,
+ )
+
+
+class Zamba2ModelTester:
+ def __init__(
+ self,
+ parent,
+ batch_size=14,
+ seq_length=7,
+ is_training=True,
+ use_input_mask=True,
+ use_labels=True,
+ vocab_size=99,
+ hidden_size=16,
+ mamba_d_state=2,
+ chunk_size=8,
+ mamba_dt_rank="auto",
+ num_hidden_layers=2,
+ num_attention_heads=2,
+ n_mamba_heads=8,
+ mamba_ngroups=8,
+ intermediate_size=4,
+ hidden_act="gelu",
+ hidden_mamba_act="silu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=16,
+ type_sequence_label_size=2,
+ initializer_range=0.02,
+ num_labels=3,
+ num_choices=4,
+ scope=None,
+ layers_block_type=["mamba", "hybrid"],
+ num_mem_blocks=1,
+ use_mem_rope=True,
+ ):
+ self.parent = parent
+ self.batch_size = batch_size
+ self.seq_length = seq_length
+ self.is_training = is_training
+ self.use_input_mask = use_input_mask
+ self.use_labels = use_labels
+ self.vocab_size = vocab_size
+ self.hidden_size = hidden_size
+ self.mamba_dt_rank = mamba_dt_rank
+ self.mamba_d_state = mamba_d_state
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.n_mamba_heads = n_mamba_heads
+ self.mamba_ngroups = mamba_ngroups
+ self.chunk_size = chunk_size
+ self.intermediate_size = intermediate_size
+ self.hidden_act = hidden_act
+ self.hidden_mamba_act = hidden_mamba_act
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.type_sequence_label_size = type_sequence_label_size
+ self.initializer_range = initializer_range
+ self.num_labels = num_labels
+ self.num_choices = num_choices
+ self.scope = scope
+ self.layers_block_type = layers_block_type
+ self.num_mem_blocks = num_mem_blocks
+ self.use_mem_rope = use_mem_rope
+
+ def prepare_config_and_inputs(self):
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
+
+ input_mask = None
+ if self.use_input_mask:
+ input_mask = random_attention_mask([self.batch_size, self.seq_length])
+
+ sequence_labels = None
+ token_labels = None
+ choice_labels = None
+ if self.use_labels:
+ sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
+ token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
+ choice_labels = ids_tensor([self.batch_size], self.num_choices)
+
+ config = self.get_config()
+
+ return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
+
+ def get_config(self):
+ return Zamba2Config(
+ vocab_size=self.vocab_size,
+ hidden_size=self.hidden_size,
+ mamba_dt_rank=self.mamba_dt_rank,
+ mamba_d_state=self.mamba_d_state,
+ num_hidden_layers=self.num_hidden_layers,
+ num_attention_heads=self.num_attention_heads,
+ n_mamba_heads=self.n_mamba_heads,
+ intermediate_size=self.intermediate_size,
+ chunk_size=self.chunk_size,
+ hidden_act=self.hidden_act,
+ mamba_ngroups=self.mamba_ngroups,
+ hidden_mamba_act=self.hidden_mamba_act,
+ hidden_dropout_prob=self.hidden_dropout_prob,
+ attention_probs_dropout_prob=self.attention_probs_dropout_prob,
+ max_position_embeddings=self.max_position_embeddings,
+ type_vocab_size=self.type_vocab_size,
+ is_decoder=True,
+ initializer_range=self.initializer_range,
+ use_mamba_kernels=False,
+ layers_block_type=self.layers_block_type,
+ num_mem_blocks=self.num_mem_blocks,
+ use_mem_rope=self.use_mem_rope,
+ )
+
+ def prepare_config_and_inputs_for_decoder(self):
+ (
+ config,
+ input_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ) = self.prepare_config_and_inputs()
+
+ config.is_decoder = True
+
+ return (
+ config,
+ input_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ )
+
+ def create_and_check_model(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels):
+ model = Zamba2Model(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask)
+ result = model(input_ids)
+ self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
+
+ def create_and_check_for_causal_lm(
+ self,
+ config,
+ input_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ):
+ model = Zamba2ForCausalLM(config=config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, labels=token_labels)
+ result = model(input_ids, attention_mask=input_mask)
+ result = model(input_ids, labels=token_labels)
+ result = model(input_ids)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
+
+ def create_and_check_decoder_model_past_large_inputs(
+ self,
+ config,
+ input_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ):
+ config.is_decoder = True
+ config.add_cross_attention = False
+ model = Zamba2ForCausalLM(config=config)
+ model.to(torch_device)
+ model.eval()
+
+ # first forward pass
+ # Attention: Zamba2 needs the cache to be initialized to return a cache!
+ past_key_values = Zamba2HybridDynamicCache(config, input_ids.shape[0], model.dtype, device=model.device)
+ outputs = model(
+ input_ids,
+ attention_mask=input_mask,
+ past_key_values=past_key_values,
+ use_cache=True,
+ )
+ past_key_values = outputs.past_key_values
+
+ # create hypothetical multiple next token and extent to next_input_ids
+ next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+ next_mask = ids_tensor((self.batch_size, 1), vocab_size=2)
+
+ # append to next input_ids and
+ next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ next_attention_mask = torch.cat([input_mask, next_mask], dim=-1)
+
+ output_from_no_past = model(
+ next_input_ids,
+ attention_mask=next_attention_mask,
+ output_hidden_states=True,
+ )["hidden_states"][0]
+ output_from_past = model(
+ next_tokens,
+ attention_mask=next_attention_mask,
+ past_key_values=past_key_values,
+ output_hidden_states=True,
+ cache_position=torch.arange(
+ input_ids.shape[1], input_ids.shape[1] + next_tokens.shape[1], device=model.device
+ ),
+ )["hidden_states"][0]
+
+ # select random slice
+ random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
+ output_from_no_past_slice = output_from_no_past[:, -1:, random_slice_idx].detach()
+ output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()
+
+ self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
+
+ # test that outputs are equal for slice
+ self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+
+ def create_and_check_for_sequence_classification(
+ self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
+ ):
+ config.num_labels = self.num_labels
+ model = Zamba2ForSequenceClassification(config)
+ model.to(torch_device)
+ model.eval()
+ result = model(input_ids, attention_mask=input_mask, labels=sequence_labels)
+ self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
+
+ def prepare_config_and_inputs_for_common(self):
+ config_and_inputs = self.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ) = config_and_inputs
+ inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
+ return config, inputs_dict
+
+
+@require_torch
+class Zamba2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
+ test_torchscript = False
+ all_model_classes = (
+ (
+ Zamba2Model,
+ Zamba2ForCausalLM,
+ Zamba2ForSequenceClassification,
+ )
+ if is_torch_available()
+ else ()
+ )
+ all_generative_model_classes = (Zamba2ForCausalLM,) if is_torch_available() else ()
+ pipeline_model_mapping = (
+ {
+ "feature-extraction": Zamba2Model,
+ "text-classification": Zamba2ForSequenceClassification,
+ "text-generation": Zamba2ForCausalLM,
+ "zero-shot": Zamba2ForSequenceClassification,
+ }
+ if is_torch_available()
+ else {}
+ )
+ test_headmasking = False
+ test_pruning = False
+
+ def setUp(self):
+ self.model_tester = Zamba2ModelTester(self)
+ self.config_tester = ConfigTester(self, config_class=Zamba2Config, hidden_size=37)
+
+ @unittest.skip("position_ids cannot be used to pad due to Mamba2 layers")
+ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
+ pass
+
+ def test_config(self):
+ self.config_tester.run_common_tests()
+
+ def test_model(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_model(*config_and_inputs)
+
+ def test_for_casual_lm(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
+
+ def test_for_sequence_classification(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
+
+ def test_decoder_model_past_with_large_inputs(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
+ self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
+
+ def test_initialization(self):
+ r"""
+ Overriding the test_initialization test as the A_log and D params of the Mamba block are initialized differently
+ """
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ configs_no_init = _config_zero_init(config)
+ for model_class in self.all_model_classes:
+ model = model_class(config=configs_no_init)
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ if "A_log" in name:
+ A = torch.arange(1, config.n_mamba_heads + 1, dtype=torch.float32)[None, :]
+ self.assertTrue(torch.allclose(param.data, torch.log(A), atol=1e-5, rtol=1e-5))
+ elif "D" in name:
+ # check if it's a ones like
+ self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))
+ elif "dt_bias" in name:
+ dt = torch.exp(
+ torch.tensor([0, 1]) * (math.log(config.time_step_max) - math.log(config.time_step_min))
+ + math.log(config.time_step_min)
+ ).clamp(min=config.time_step_floor)
+ inv_dt = dt + torch.log(-torch.expm1(-dt))
+ if param.requires_grad:
+ self.assertTrue(param.data.max().item() <= inv_dt[1])
+ self.assertTrue(param.data.min().item() >= inv_dt[0])
+ else:
+ self.assertIn(
+ ((param.data.mean() * 1e9).round() / 1e9).item(),
+ [0.0, 1.0],
+ msg=f"Parameter {name} of model {model_class} seems not properly initialized",
+ )
+
+ def test_mismatched_shapes_have_properly_initialized_weights(self):
+ r"""
+ Overriding the test_mismatched_shapes_have_properly_initialized_weights test because A_log and D params of the
+ Mamba block are initialized differently and we tested that in test_initialization
+ """
+ self.skipTest("Cumbersome and redundant for Zamba2")
+
+ def test_attention_outputs(self):
+ r"""
+ Overriding the test_attention_outputs test as the Zamba2 model outputs attention only for its attention layers
+ """
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ config.return_dict = True
+
+ seq_len = getattr(self.model_tester, "seq_length", None)
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
+
+ for model_class in self.all_model_classes:
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = False
+ config.return_dict = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+
+ # check that output_attentions also work using config
+ del inputs_dict["output_attentions"]
+ config.output_attentions = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+ attentions = outputs.attentions
+
+ self.assertListEqual(
+ list(attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+ )
+ out_len = len(outputs)
+
+ # Check attention is always last and order is fine
+ inputs_dict["output_attentions"] = True
+ inputs_dict["output_hidden_states"] = True
+ model = model_class(config)
+ model.to(torch_device)
+ model.eval()
+ with torch.no_grad():
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
+
+ added_hidden_states = 1
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
+
+ self_attentions = outputs.attentions
+
+ self.assertListEqual(
+ list(self_attentions[0].shape[-3:]),
+ [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
+ )
+
+ def _get_input_ids_and_config(self):
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
+ (
+ config,
+ input_ids,
+ input_mask,
+ sequence_labels,
+ token_labels,
+ choice_labels,
+ ) = config_and_inputs
+ return config, input_ids, input_mask
+
+ def test_left_padding_compatibility(self):
+ r"""
+ Overriding the test_left_padding_compatibility test as the mamba layers accentuate the numerical differences
+ effect of the left padding discussed in the issue in the note. Using a more permissive tolerance value.
+ """
+ import inspect
+ # NOTE: left-padding results in small numerical differences. This is expected.
+ # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
+
+ # First, filter out models that don't support left padding - generative and decoder-only.
+ # Zamba2 is a decoder-only architecture
+ decoder_only_classes = self.all_generative_model_classes
+
+ # Then, test left-padding
+ def _prepare_model_kwargs(input_ids, attention_mask, signature):
+ model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
+ if "position_ids" in signature:
+ position_ids = torch.cumsum(attention_mask, dim=-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ model_kwargs["position_ids"] = position_ids
+ if "cache_position" in signature:
+ cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
+ model_kwargs["cache_position"] = cache_position
+ return model_kwargs
+
+ for model_class in decoder_only_classes:
+ config, input_ids, attention_mask = self._get_input_ids_and_config()
+ model = model_class(config).to(torch_device).eval()
+ signature = inspect.signature(model.forward).parameters.keys()
+
+ # Without padding
+ model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
+ next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
+
+ # With left-padding (length 32)
+ pad_size = (input_ids.shape[0], 32)
+ padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
+ padded_input_ids = torch.cat((padding, input_ids), dim=1)
+ padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
+ model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
+ next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
+
+ # They should result in very similar logits
+ self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
+
+ @require_flash_attn
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_fp32_ln(self):
+ r"""
+ Overriding the test_flash_attn_2_fp32_ln test as the Zamba2 model, like Mixtral, doesn't support
+ right padding + use cache with FA2
+ """
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_input = inputs_dict[model.main_input_name]
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Zamba2 does not support right padding + use_cache with FA2.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_2",
+ low_cpu_mem_usage=True,
+ load_in_4bit=True,
+ )
+
+ for _, param in model.named_parameters():
+ # upcast only layer norms
+ if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+ param.data = param.data.to(torch.float32)
+
+ _ = model(dummy_input)
+ # with attention mask
+ _ = model(dummy_input, attention_mask=dummy_attention_mask)
+
+ @require_flash_attn
+ @require_torch_gpu
+ @pytest.mark.flash_attn_test
+ @slow
+ def test_flash_attn_2_inference_equivalence_right_padding(self):
+ r"""
+ Overriding the test_flash_attn_2_inference_padding_right test as the Zamba2 model, like Mixtral, doesn't support
+ right padding + use cache with FA2
+ """
+ self.skipTest(reason="Zamba2 flash attention does not support right padding")
+
+ @unittest.skip(reason="Zamba2 has its own special cache type")
+ @parameterized.expand([(1, False), (1, True), (4, False)])
+ def test_new_cache_format(self, num_beams, do_sample):
+ pass
+
+
+@require_torch
+class Zamba2ModelIntegrationTest(unittest.TestCase):
+ model = None
+ tokenizer = None
+
+ @classmethod
+ @slow
+ def setUpClass(cls):
+ model_id = "Zyphra/Zamba2-1.2B"
+ cls.model = Zamba2ForCausalLM.from_pretrained(
+ model_id, torch_dtype=torch.float32, low_cpu_mem_usage=True, revision="PR"
+ )
+ cls.tokenizer = AutoTokenizer.from_pretrained(model_id, revision="PR")
+
+ @slow
+ def test_simple_generate(self):
+ self.model.to(torch_device)
+
+ input_ids = self.tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")[
+ "input_ids"
+ ].to(torch_device)
+ out = self.model.generate(input_ids, do_sample=False, max_new_tokens=10)
+ output_sentence = self.tokenizer.decode(out[0, :])
+ self.assertEqual(
+ output_sentence,
+ " Hey how are you doing on this lovely evening?\n\nI'm doing well, thanks for",
+ )
+
+ with torch.no_grad():
+ logits = self.model(input_ids=input_ids).logits.to(dtype=torch.float32)
+
+ EXPECTED_LOGITS_NO_GRAD = torch.tensor(
+ [
+ -5.9587, 10.5152, 7.0382, -2.8728, -4.8143, -4.8142, -4.8142, -4.8144,
+ -4.8143, -4.8143, -4.8142, -4.8142, 6.0185, 18.0037, -4.8142, -4.8144,
+ -4.8143, -4.8142, -4.8143, -4.8143, -4.8143, -4.8143, -4.8142, -4.8143,
+ -4.8144, -4.8143, -4.8143, -4.8141, -4.8142, -4.8142, -4.8142, -4.8144,
+ -4.8143, -4.8143, -4.8143, -4.8142, -4.8144, -4.8144, -4.8142, -4.8142
+ ]
+ , dtype=torch.float32) # fmt: skip
+ torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD, rtol=1e-3, atol=1e-3)
+
+ @slow
+ def test_simple_batched_generate_with_padding(self):
+ self.model.to(torch_device)
+
+ inputs = self.tokenizer(
+ ["Hey how are you doing on this lovely evening?", "When did the Roman empire "],
+ padding=True,
+ return_tensors="pt",
+ ).to(torch_device)
+ out = self.model.generate(**inputs, do_sample=False, max_new_tokens=10)
+ output_sentences = self.tokenizer.batch_decode(out)
+ self.assertEqual(
+ output_sentences[0],
+ " Hey how are you doing on this lovely evening?\n\nI'm doing well, thanks for",
+ )
+
+ self.assertEqual(
+ output_sentences[1],
+ "[PAD][PAD][PAD][PAD] When did the Roman empire 1st fall?\nThe Roman Empire fell in",
+ )
+
+ with torch.no_grad():
+ logits = self.model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits.to(
+ dtype=torch.float32
+ )
+
+ EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
+ [
+ -5.9611, 10.5208, 7.0411, -2.8743, -4.8167, -4.8167, -4.8167, -4.8168,
+ -4.8167, -4.8167, -4.8167, -4.8166, 6.0218, 18.0062, -4.8167, -4.8168,
+ -4.8167, -4.8167, -4.8167, -4.8168, -4.8168, -4.8168, -4.8167, -4.8167,
+ -4.8168, -4.8167, -4.8167, -4.8165, -4.8167, -4.8167, -4.8167, -4.8169,
+ -4.8168, -4.8168, -4.8168, -4.8166, -4.8169, -4.8168, -4.8167, -4.8167
+ ]
+ , dtype=torch.float32) # fmt: skip
+
+ EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
+ [
+ 0.1966, 6.3449, 3.8350, -5.7291, -6.5106, -6.5104, -6.5103, -6.5104,
+ -6.5103, -6.5104, -6.5106, -6.5105, 7.8700, 13.5434, -6.5104, -6.5096,
+ -6.5106, -6.5102, -6.5106, -6.5106, -6.5105, -6.5106, -6.5104, -6.5106,
+ -6.5105, -6.5106, -6.5106, -6.5113, -6.5102, -6.5105, -6.5108, -6.5105,
+ -6.5104, -6.5106, -6.5106, -6.5104, -6.5106, -6.5107, -6.5103, -6.5105 ]
+ , dtype=torch.float32) # fmt: skip
+
+ torch.testing.assert_close(logits[0, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_0, rtol=1e-3, atol=1e-3)
+ torch.testing.assert_close(logits[1, -1, :40].cpu(), EXPECTED_LOGITS_NO_GRAD_1, rtol=1e-3, atol=1e-3)
diff --git a/utils/modular_model_converter.textClipping b/utils/modular_model_converter.textClipping
new file mode 100644
index 00000000000000..93a7b0661be5b2
Binary files /dev/null and b/utils/modular_model_converter.textClipping differ