[docs] update input documentation for MAMBA2 and MISTRAL models to include cache_position and attention_mask details (huggingface#34322)

* [docs] update input documentation for MAMBA2 and MISTRAL models to include cache_position and attention_mask details

* [docs] correct input documentation for MISTRAL model to reference `input_ids` instead of `decoder_input_ids`

* [docs] clarify cache_position description in MISTRAL model documentation
h3110Fr13nd authored and BernardZach committed Dec 6, 2024
1 parent e694967 commit 6c6066a
Showing 2 changed files with 15 additions and 1 deletion.
src/transformers/models/mamba2/modeling_mamba2.py (10 additions, 0 deletions)
@@ -805,6 +805,16 @@ class Mamba2CausalLMOutput(ModelOutput):
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
If `cache_params` is passed, `cache_position` should also be passed.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
"""


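For readers of the updated MAMBA2 docstring, a minimal sketch of how the two newly documented arguments are typically passed during manual cached decoding (not part of this commit; the checkpoint id below is a placeholder, and the greedy step is only illustrative):

```python
import torch
from transformers import AutoTokenizer, Mamba2ForCausalLM

model_id = "<your-mamba2-checkpoint>"  # placeholder; substitute a real Mamba2 checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = Mamba2ForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The capital of France is", return_tensors="pt")

# Prefill: run the whole prompt once. `attention_mask` marks real tokens (1)
# versus padding (0); with a single unpadded prompt it is all ones.
prefill = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    use_cache=True,
)

# One greedy decode step: when the cache from the prefill is reused,
# `cache_position` tells the model where in the sequence this token sits,
# so the cache is updated at the correct offset.
next_token = prefill.logits[:, -1].argmax(dim=-1, keepdim=True)
step = model(
    input_ids=next_token,
    cache_params=prefill.cache_params,
    cache_position=torch.tensor([inputs.input_ids.shape[1]]),
    use_cache=True,
)
```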
src/transformers/models/mistral/modeling_mistral.py (5 additions, 1 deletion)
@@ -619,7 +619,7 @@ def _init_weights(self, module):
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
- If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
@@ -666,6 +666,10 @@ def _init_weights(self, module):
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices indicating the position of the input sequence tokens in the sequence. Unlike `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""


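Similarly, a minimal sketch (not part of this commit) of how the newly documented `cache_position` argument travels alongside `past_key_values` in a manual decoding step with Mistral; the checkpoint, prompt, and greedy step are only illustrative:

```python
import torch
from transformers import AutoTokenizer, MistralForCausalLM

model_id = "mistralai/Mistral-7B-v0.1"  # any Mistral checkpoint works here
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = MistralForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
seq_len = inputs.input_ids.shape[1]

# Prefill: `cache_position` covers every prompt position. Unlike `position_ids`,
# it is not affected by padding.
prefill = model(
    **inputs,
    use_cache=True,
    cache_position=torch.arange(seq_len),
)

# Decode step: with `past_key_values` only the last token is fed, and
# `cache_position` carries its absolute position so the KV cache is
# updated in the right slot.
next_token = prefill.logits[:, -1].argmax(dim=-1, keepdim=True)
step = model(
    input_ids=next_token,
    attention_mask=torch.cat([inputs.attention_mask, torch.ones_like(next_token)], dim=-1),
    past_key_values=prefill.past_key_values,
    use_cache=True,
    cache_position=torch.tensor([seq_len]),
)
```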
