Moved mamba weight init into _init_weights
pglorio committed Sep 9, 2024
2 parents 1a521de + 4ab88a2 commit 767a591
Showing 1 changed file with 8 additions and 1 deletion.
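As background for the commit title: "mamba weight init" refers to the Mamba-specific parameters (the state matrix stored as A_log, the skip parameter D, and the dt_proj bias). Below is a minimal, illustrative sketch of the kind of initialization a _init_weights hook typically performs for such a mixer. It is not the code from this commit; the tensor shapes, names, and time-step range are assumptions for illustration only.

import math
import torch

def init_mamba_params(A_log: torch.Tensor, D: torch.Tensor, dt_bias: torch.Tensor,
                      d_state: int, time_step_min: float = 1e-3, time_step_max: float = 1e-1):
    # Illustrative sketch only, not the actual diff of commit 767a591.
    intermediate_size = A_log.shape[0]
    with torch.no_grad():
        # S4D-real style initialization: A[i, n] = n + 1, stored in log space
        A = torch.arange(1, d_state + 1, dtype=torch.float32).expand(intermediate_size, d_state)
        A_log.copy_(torch.log(A))
        # Skip connection D starts at one
        D.fill_(1.0)
        # Sample dt log-uniformly in [time_step_min, time_step_max], then invert softplus
        dt = torch.exp(
            torch.rand(intermediate_size) * (math.log(time_step_max) - math.log(time_step_min))
            + math.log(time_step_min)
        ).clamp(min=1e-4)
        dt_bias.copy_(dt + torch.log(-torch.expm1(-dt)))  # inverse of softplus(x) = log(1 + exp(x))

# Example with illustrative shapes:
# A_log = torch.empty(1024, 16); D = torch.empty(1024); dt_bias = torch.empty(1024)
# init_mamba_params(A_log, D, dt_bias, d_state=16)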
src/transformers/models/zamba/modeling_zamba.py (8 additions, 1 deletion)

@@ -508,6 +508,13 @@ class ZambaMambaMixer(nn.Module):
A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
and is why Mamba is called **selective** state spaces)
This module differs from `transformers.models.mamba.modeling_mamba.MambaMixer` in two ways:
- Added multi-head: the output of `self.in_proj` is split into `self.n_mamba_heads` heads, and each head
undergoes an independent forward pass, identical to the original `MambaMixer`, up until the pre-activations of
`self.out_proj`. The pre-activations, coming from different mamba heads, are then concatenated and fed into `self.out_proj`.
- Added `attention_mask` for batched inference: this tensor multiplies the input and output of the convolution layer,
zeroing out the embeddings associated with `attention_mask == 0` and thus preventing the layer from attending to such tokens.
"""

def __init__(self, config: ZambaConfig, layer_idx):
@@ -1342,7 +1349,7 @@ def _update_causal_mask(self, attention_mask, input_tensor, cache_position):

return causal_mask


# Copied from transformers.models.jamba.modeling_jamba.JambaForCausalLM with Jamba->Zamba, JAMBA->ZAMBA
class ZambaForCausalLM(ZambaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]

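The new docstring in the first hunk describes two behaviours: splitting the output of `self.in_proj` into `self.n_mamba_heads` independent heads, and multiplying the convolution input and output by `attention_mask` so padded positions stay zero. Below is a minimal sketch of those two ideas only, not the library's implementation; the helper names, arguments, and shapes are assumptions for illustration.

import torch

def multihead_mamba_sketch(hidden_states, attention_mask, in_proj, heads, out_proj, n_mamba_heads):
    # hidden_states: (batch, seq_len, hidden_size); attention_mask: (batch, seq_len) of 0/1
    projected = in_proj(hidden_states)                    # (batch, seq_len, projected_size)
    head_inputs = projected.chunk(n_mamba_heads, dim=-1)  # one slice per mamba head
    head_outputs = []
    for head, head_in in zip(heads, head_inputs):
        # Zero out padded tokens before and after each head's conv/SSM forward pass
        head_in = head_in * attention_mask.unsqueeze(-1)
        head_out = head(head_in)                          # independent MambaMixer-style forward
        head_outputs.append(head_out * attention_mask.unsqueeze(-1))
    # Concatenate the per-head pre-activations and apply the shared output projection
    return out_proj(torch.cat(head_outputs, dim=-1))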
