From 5aace7c1452640cc82404c017a6f7f1590282a4c Mon Sep 17 00:00:00 2001 From: Tomer Asida Date: Wed, 17 Apr 2024 00:40:04 +0300 Subject: [PATCH] fix docs --- docs/source/en/_toctree.yml | 2 ++ .../models/jamba/modeling_jamba.py | 32 ++++++++++++++----- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 7daf91c99df087..38e4763d0a4cae 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -382,6 +382,8 @@ title: HerBERT - local: model_doc/ibert title: I-BERT + - local: model_doc/jamba + title: Jamba - local: model_doc/jukebox title: Jukebox - local: model_doc/led diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 20f6820dc94d9b..16b7c644d3633a 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -225,11 +225,11 @@ class HybridMambaAttentionDynamicCache(DynamicCache): This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states` and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape for each tensor - For attention layers, `key_cache` and `value_cache` have a shape of `[batch_size, num_heads, seq_len, head_dim]`, - while `conv_states` and `ssm_states` have a shape of `[batch_size, 0]` (empty tensors). - For mamba layers, `key_cache` and `value_cache` have a shape of `[batch_size, 0]` (empty tensors), - while `conv_states` represents the convolution state and has a shape of `[batch_size, d_inner, d_conv]`, - and `ssm_states` represents the ssm state and has a shape of `[batch_size, d_inner, d_state]`. + For attention layers, `key_cache` and `value_cache` have a shape of `(batch_size, num_heads, seq_len, head_dim)`, + while `conv_states` and `ssm_states` have a shape of `(batch_size, 0)` (empty tensors). + For mamba layers, `key_cache` and `value_cache` have a shape of `(batch_size, 0)` (empty tensors), + while `conv_states` represents the convolution state and has a shape of `(batch_size, d_inner, d_conv)`, + and `ssm_states` represents the ssm state and has a shape of `(batch_size, d_inner, d_state)`. """ def __init__(self, config, batch_size, dtype=torch.float16, device=None): @@ -1304,9 +1304,9 @@ def _init_weights(self, module): A HybridMambaAttentionDynamicCache object containing pre-computed hidden-states (keys and values in the self-attention blocks and convolution and ssm states in the mamba blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. - Key and value cache tensors have shape `[batch_size, num_heads, seq_len, head_dim]`. - Convolution and ssm states tensors have shape `[batch_size, d_inner, d_conv]` and - `[batch_size, d_inner, d_state]` respectively. + Key and value cache tensors have shape `(batch_size, num_heads, seq_len, head_dim)`. + Convolution and ssm states tensors have shape `(batch_size, d_inner, d_conv)` and + `(batch_size, d_inner, d_state)` respectively. See the `HybridMambaAttentionDynamicCache` class for more details. If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that @@ -1605,6 +1605,22 @@ def forward( can save memory, which becomes pretty significant for long sequences. Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, JambaForCausalLM + + >>> model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions