From df8dfd302c173ff5880eb85eafeccf8a4316d59d Mon Sep 17 00:00:00 2001
From: pglorio <85982602+pglorio@users.noreply.github.com>
Date: Mon, 9 Sep 2024 15:41:44 -0700
Subject: [PATCH 1/2] Update ZambaForCausalLM

---
 src/transformers/models/zamba/modeling_zamba.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py
index 45b3c7009a0a06..45ef249276b788 100644
--- a/src/transformers/models/zamba/modeling_zamba.py
+++ b/src/transformers/models/zamba/modeling_zamba.py
@@ -1327,7 +1327,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
 
         return causal_mask
 
-
+# Copied from transformers.models.jamba.modeling_jamba.JambaForCausalLM with Jamba->Zamba, JAMBA->ZAMBA
 class ZambaForCausalLM(ZambaPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 

From 4ab88a29e8b17696f1b013d4f9d68a5e59fcf83d Mon Sep 17 00:00:00 2001
From: pglorio <85982602+pglorio@users.noreply.github.com>
Date: Mon, 9 Sep 2024 16:03:41 -0700
Subject: [PATCH 2/2] Describe diffs with original mamba layer

---
 src/transformers/models/zamba/modeling_zamba.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py
index 45ef249276b788..5a14d9d7a0fa0a 100644
--- a/src/transformers/models/zamba/modeling_zamba.py
+++ b/src/transformers/models/zamba/modeling_zamba.py
@@ -508,6 +508,13 @@ class ZambaMambaMixer(nn.Module):
     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
     and is why Mamba is called **selective** state spaces)
+
+    This module differs from `transformers.models.mamba.modeling_mamba.MambaMixer` in two ways:
+    - Added multi-head: the output of `self.in_proj` is split into `self.n_mamba_heads` heads, and each head
+      undergoes an independent forward pass, identical to the original `MambaMixer`, up to the pre-activations of
+      `self.out_proj`. The pre-activations from the different mamba heads are then concatenated and fed into `self.out_proj`.
+    - Added `attention_mask` for batched inference: this tensor multiplies the input and output of the convolution layer,
+      setting to zero the embeddings associated with `attention_mask == 0` and thus preventing the layer from attending to such tokens.
     """

     def __init__(self, config: ZambaConfig, layer_idx):
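
As a rough illustration of the two differences described in the new `ZambaMambaMixer` docstring, the sketch below splits the `in_proj` output into heads and multiplies the attention mask into the convolution input and output. It is not the actual Zamba implementation: the helper names (`head_mixers`, `mixer.conv1d`, `mixer.ssm`) and the tensor shapes are assumptions made for this example only.

# Illustrative sketch only -- NOT the Zamba implementation. `head_mixers` is a
# hypothetical list of per-head modules, each exposing `conv1d` and `ssm`
# (standing in for the original-MambaMixer-style forward pass of one head).
import torch

def multihead_masked_mix(hidden_states, attention_mask, in_proj, head_mixers, out_proj):
    # hidden_states: (batch, seq_len, hidden_size); attention_mask: (batch, seq_len), 0 marks padding
    projected = in_proj(hidden_states)                        # (batch, seq_len, n_heads * head_dim)
    head_inputs = projected.chunk(len(head_mixers), dim=-1)   # one chunk per mamba head

    pre_activations = []
    for head_input, mixer in zip(head_inputs, head_mixers):
        x = head_input.transpose(1, 2)                        # (batch, head_dim, seq_len) for conv1d
        x = x * attention_mask.unsqueeze(1)                   # zero padded positions before the conv
        x = mixer.conv1d(x)[..., : head_input.shape[1]]       # causal conv, truncated to seq_len
        x = x * attention_mask.unsqueeze(1)                   # and again after the conv
        pre_activations.append(mixer.ssm(x))                  # per-head selective scan, as in MambaMixer

    # concatenate the per-head pre-activations, then apply a single shared out_proj
    stacked = torch.cat(pre_activations, dim=1)               # (batch, n_heads * head_dim, seq_len)
    return out_proj(stacked.transpose(1, 2))                  # back to (batch, seq_len, hidden_size)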