Skip to content

Commit

Permalink
Describe diffs with original mamba layer
Browse files Browse the repository at this point in the history
  • Loading branch information
pglorio authored Sep 9, 2024
1 parent df8dfd3 commit 4ab88a2
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/transformers/models/zamba/modeling_zamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,13 @@ class ZambaMambaMixer(nn.Module):
A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
and is why Mamba is called **selective** state spaces)
This module differs from `transformers.models.mamba.modeling_mamba.MambaMixer` in two ways:
- Added multi-head: the output of `self.in_proj` is split into `self.n_mamba_heads` heads, and each head
undergoes an independent forward pass, identical to the original `MambaMixer`, up until the pre-activations of
`self.out_proj`. The pre-activations, coming from different mamba heads, are then concatenated and fed into `self.out_proj`.
- Added `attention_mask` for batched inference: this tensor multiplies input and output of the convolution layer, setting
to zero embeddings associated with `attention_mask == 0` thus preventing the layer to attend to such tokens.
"""

def __init__(self, config: ZambaConfig, layer_idx):
Expand Down

0 comments on commit 4ab88a2

Please sign in to comment.