
Commit

fix docs
tomeras91 committed Apr 16, 2024
1 parent c9f094a commit 5aace7c
Showing 2 changed files with 26 additions and 8 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -382,6 +382,8 @@
     title: HerBERT
   - local: model_doc/ibert
     title: I-BERT
+  - local: model_doc/jamba
+    title: Jamba
   - local: model_doc/jukebox
     title: Jukebox
   - local: model_doc/led
32 changes: 24 additions & 8 deletions src/transformers/models/jamba/modeling_jamba.py
@@ -225,11 +225,11 @@ class HybridMambaAttentionDynamicCache(DynamicCache):
 This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
 and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor
-For attention layers, `key_cache` and `value_cache` have a shape of `[batch_size, num_heads, seq_len, head_dim]`,
-while `conv_states` and `ssm_states` have a shape of `[batch_size, 0]` (empty tensors).
-For mamba layers, `key_cache` and `value_cache` have a shape of `[batch_size, 0]` (empty tensors),
-while `conv_states` represents the convolution state and has a shape of `[batch_size, d_inner, d_conv]`,
-and `ssm_states` represents the ssm state and has a shape of `[batch_size, d_inner, d_state]`.
+For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`,
+while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors).
+For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors),
+while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`,
+and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`.
 """

 def __init__(self, config, batch_size, dtype=torch.float16, device=None):
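As a quick sanity check of the shapes described above, here is a minimal sketch (not part of this commit) that builds the cache from a default `JambaConfig` and prints the per-layer state shapes; the attribute names `key_cache`, `conv_states` and `ssm_states` are the ones named in the docstring, everything else is illustrative:

```python
# Minimal sketch (not from this commit): inspect the per-layer shapes documented above.
# A default JambaConfig is assumed to be good enough for a shape check; no weights are needed.
import torch
from transformers import JambaConfig
from transformers.models.jamba.modeling_jamba import HybridMambaAttentionDynamicCache

config = JambaConfig()
cache = HybridMambaAttentionDynamicCache(config, batch_size=1, dtype=torch.float32, device="cpu")

for layer_idx in range(config.num_hidden_layers):
    print(
        layer_idx,
        tuple(cache.key_cache[layer_idx].shape),    # (batch_size, 0) until attention keys are written
        tuple(cache.conv_states[layer_idx].shape),  # (batch_size, d_inner, d_conv) on mamba layers, else (batch_size, 0)
        tuple(cache.ssm_states[layer_idx].shape),   # (batch_size, d_inner, d_state) on mamba layers, else (batch_size, 0)
    )
```

Before any forward pass this should show empty placeholders on the attention side and allocated convolution/ssm state tensors on the mamba layers, matching the parenthesized shapes the commit documents.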
@@ -1304,9 +1304,9 @@ def _init_weights(self, module):
 A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the
 self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see
 `past_key_values` input) to speed up sequential decoding.
-Key and value cache tensors have shape `[batch_size, num_heads, seq_len, head_dim]`.
-Convolution and ssm states tensors have shape `[batch_size, d_inner, d_conv]` and
-`[batch_size, d_inner, d_state]` respectively.
+Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`.
+Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and
+`(batch_size, d_inner, d_state)` respectively.
 See the `HybridMambaAttentionDynamicCache` class for more details.
 If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
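To illustrate the `past_key_values` flow this docstring describes, here is a hedged sketch (mine, not the commit's): it assumes the `ai21labs/Jamba-v0.1` weights fit in memory and disables the optimized mamba kernels so the pure-PyTorch fallback path is used; only the class, constructor signature and checkpoint name come from the source, the rest is illustrative:

```python
# Hedged sketch (not from this commit): prefill a HybridMambaAttentionDynamicCache once,
# so that later decoding steps only need to feed the newly generated token ids.
import torch
from transformers import AutoTokenizer, JambaForCausalLM
from transformers.models.jamba.modeling_jamba import HybridMambaAttentionDynamicCache

tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
# use_mamba_kernels=False selects the pure-PyTorch mamba path; drop it if the optimized
# kernels (mamba-ssm, causal-conv1d) are installed and the model runs on GPU.
model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", use_mamba_kernels=False)

inputs = tokenizer("Jamba is a hybrid", return_tensors="pt")
cache = HybridMambaAttentionDynamicCache(model.config, batch_size=1, dtype=model.dtype)

with torch.no_grad():
    # The prompt pass fills the attention key/value tensors and the mamba conv/ssm states in place.
    outputs = model(**inputs, past_key_values=cache, use_cache=True)

# The filled cache is returned as `outputs.past_key_values`; generation reuses it so each
# subsequent step only has to pass the last input_ids, as described above.
```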
@@ -1605,6 +1605,22 @@ def forward(
 can save memory, which becomes pretty significant for long sequences.
 Returns:
+Example:
+```python
+>>> from transformers import AutoTokenizer, JambaForCausalLM
+>>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1")
+>>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")
+>>> prompt = "Hey, are you conscious? Can you talk to me?"
+>>> inputs = tokenizer(prompt, return_tensors="pt")
+>>> # Generate
+>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+```"""

 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
