Idefics2 generation erroring with flash_attention_2 #32237

Closed
tctrautman opened this issue Jul 26, 2024 · 2 comments · Fixed by #32241
@tctrautman

System Info

- `transformers` version: 4.44.0.dev0
- Platform: Linux-5.4.0-155-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.24.2
- Safetensors version: 0.4.3
- Accelerate version: 0.33.0
- Accelerate config: 	not found
- PyTorch version (GPU?): 2.1.1+cu121 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: no
- Using GPU in script?: yes (see script)
- GPU type: NVIDIA RTX A6000

Who can help?

@zucchini-nlp

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

The script below is the same as the one included in the Idefics2 blog post, with three additional arguments added to the AutoModelForVision2Seq.from_pretrained call; comments mark the new lines.

import requests
import torch
from PIL import Image

from transformers import AutoProcessor, AutoModelForVision2Seq
from transformers.image_utils import load_image

DEVICE = "cuda:0"
dtype = torch.bfloat16

# Note that passing the image urls (instead of the actual pil images) to the processor is also possible
image1 = load_image("https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg")
image2 = load_image("https://cdn.britannica.com/59/94459-050-DBA42467/Skyline-Chicago.jpg")
image3 = load_image("https://cdn.britannica.com/68/170868-050-8DDE8263/Golden-Gate-Bridge-San-Francisco.jpg")


processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    attn_implementation="flash_attention_2", # This is a new line
    torch_dtype=dtype, # This is a new line
    device_map=DEVICE, # This is a new line
).to(DEVICE)


# Create inputs
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What do we see in this image?"},
        ]
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "In this image, we can see the city of New York, and more specifically the Statue of Liberty."},
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "And how about this image?"},
        ]
    },
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=prompt, images=[image1, image2], return_tensors="pt")
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}


# Generate
generated_ids = model.generate(**inputs, max_new_tokens=500)
generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)

print(generated_texts)

Running this script yields the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 55
     51 inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
     54 # Generate
---> 55 generated_ids = model.generate(**inputs, max_new_tokens=500)
     56 generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
     58 print(generated_texts)

File /usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1990, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1982     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1983         input_ids=input_ids,
   1984         expand_size=generation_config.num_return_sequences,
   1985         is_encoder_decoder=self.config.is_encoder_decoder,
   1986         **model_kwargs,
   1987     )
   1989     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 1990     result = self._sample(
   1991         input_ids,
   1992         logits_processor=prepared_logits_processor,
   1993         logits_warper=prepared_logits_warper,
   1994         stopping_criteria=prepared_stopping_criteria,
   1995         generation_config=generation_config,
   1996         synced_gpus=synced_gpus,
   1997         streamer=streamer,
   1998         **model_kwargs,
   1999     )
   2001 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2002     # 11. prepare logits warper
   2003     prepared_logits_warper = (
   2004         self._get_logits_warper(generation_config, device=input_ids.device)
   2005         if generation_config.do_sample
   2006         else None
   2007     )

File /usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:2933, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2930 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   2932 # forward pass to get next token
-> 2933 outputs = self(**model_inputs, return_dict=True)
   2935 if synced_gpus and this_peer_finished:
   2936     continue  # don't waste resources running the code we don't need

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1575, in Idefics2ForConditionalGeneration.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   1572 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1574 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1575 outputs = self.model(
   1576     input_ids=input_ids,
   1577     attention_mask=attention_mask,
   1578     position_ids=position_ids,
   1579     past_key_values=past_key_values,
   1580     inputs_embeds=inputs_embeds,
   1581     pixel_values=pixel_values,
   1582     pixel_attention_mask=pixel_attention_mask,
   1583     image_hidden_states=image_hidden_states,
   1584     use_cache=use_cache,
   1585     output_attentions=output_attentions,
   1586     output_hidden_states=output_hidden_states,
   1587     return_dict=return_dict,
   1588 )
   1590 hidden_states = outputs[0]
   1591 logits = self.lm_head(hidden_states)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/transformers/models/idefics2/modeling_idefics2.py:1408, in Idefics2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, pixel_values, pixel_attention_mask, image_hidden_states, use_cache, output_attentions, output_hidden_states, return_dict)
   1399 if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
   1400     # When we generate, we don't want to replace the potential image_token_id that we generated by images
   1401     # that simply don't exist
   1402     inputs_embeds = self.inputs_merger(
   1403         input_ids=input_ids,
   1404         inputs_embeds=inputs_embeds,
   1405         image_hidden_states=image_hidden_states,
   1406     )
-> 1408 outputs = self.text_model(
   1409     inputs_embeds=inputs_embeds,
   1410     attention_mask=attention_mask,
   1411     position_ids=position_ids,
   1412     past_key_values=past_key_values,
   1413     output_attentions=output_attentions,
   1414     output_hidden_states=output_hidden_states,
   1415     return_dict=return_dict,
   1416 )
   1418 if return_legacy_cache and use_cache:
   1419     outputs.past_key_values = outputs.past_key_values.to_legacy_cache()

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:805, in MistralModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    794     layer_outputs = self._gradient_checkpointing_func(
    795         decoder_layer.__call__,
    796         hidden_states,
   (...)
    802         cache_position,
    803     )
    804 else:
--> 805     layer_outputs = decoder_layer(
    806         hidden_states,
    807         attention_mask=causal_mask,
    808         position_ids=position_ids,
    809         past_key_value=past_key_values,
    810         output_attentions=output_attentions,
    811         use_cache=use_cache,
    812         cache_position=cache_position,
    813     )
    815 hidden_states = layer_outputs[0]
    817 if use_cache:

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:546, in MistralDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
    543 hidden_states = self.input_layernorm(hidden_states)
    545 # Self Attention
--> 546 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    547     hidden_states=hidden_states,
    548     attention_mask=attention_mask,
    549     position_ids=position_ids,
    550     past_key_value=past_key_value,
    551     output_attentions=output_attentions,
    552     use_cache=use_cache,
    553     cache_position=cache_position,
    554     **kwargs,
    555 )
    556 hidden_states = residual + hidden_states
    558 # Fully Connected

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /usr/local/lib/python3.10/dist-packages/transformers/models/mistral/modeling_mistral.py:379, in MistralFlashAttention2.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    376 key_states = key_states.transpose(1, 2)
    377 value_states = value_states.transpose(1, 2)
--> 379 attn_output = _flash_attention_forward(
    380     query_states,
    381     key_states,
    382     value_states,
    383     attention_mask,
    384     q_len,
    385     position_ids=position_ids,
    386     dropout=dropout_rate,
    387     sliding_window=getattr(self.config, "sliding_window", None),
    388     use_top_left_mask=self._flash_attn_uses_top_left_mask,
    389     is_causal=self.is_causal,
    390 )
    392 attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
    393 attn_output = self.o_proj(attn_output)

File /usr/local/lib/python3.10/dist-packages/transformers/modeling_flash_attention_utils.py:278, in _flash_attention_forward(query_states, key_states, value_states, attention_mask, query_length, is_causal, dropout, position_ids, softmax_scale, sliding_window, use_top_left_mask, softcap, deterministic)
    275     cu_seqlens_q, cu_seqlens_k = cu_seq_lens
    276     max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
--> 278     attn_output = flash_attn_varlen_func(
    279         query_states,
    280         key_states,
    281         value_states,
    282         cu_seqlens_q=cu_seqlens_q,
    283         cu_seqlens_k=cu_seqlens_k,
    284         max_seqlen_q=max_seqlen_in_batch_q,
    285         max_seqlen_k=max_seqlen_in_batch_k,
    286         dropout_p=dropout,
    287         softmax_scale=softmax_scale,
    288         causal=causal,
    289         **flash_kwargs,
    290     )
    292     attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
    294 else:

File /usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:1124, in flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_attn_probs, block_table)
   1051 def flash_attn_varlen_func(
   1052     q,
   1053     k,
   (...)
   1067     block_table=None,
   1068 ):
   1069     """dropout_p should be set to 0.0 during evaluation
   1070     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
   1071     than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
   (...)
   1122             pattern (negative means that location was dropped, nonnegative means it was kept).
   1123     """
-> 1124     return FlashAttnVarlenFunc.apply(
   1125         q,
   1126         k,
   1127         v,
   1128         cu_seqlens_q,
   1129         cu_seqlens_k,
   1130         max_seqlen_q,
   1131         max_seqlen_k,
   1132         dropout_p,
   1133         softmax_scale,
   1134         causal,
   1135         window_size,
   1136         softcap,
   1137         alibi_slopes,
   1138         deterministic,
   1139         return_attn_probs,
   1140         block_table,
   1141     )

File /usr/local/lib/python3.10/dist-packages/torch/autograd/function.py:539, in Function.apply(cls, *args, **kwargs)
    536 if not torch._C._are_functorch_transforms_active():
    537     # See NOTE: [functorch vjp and autograd interaction]
    538     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 539     return super().apply(*args, **kwargs)  # type: ignore[misc]
    541 if cls.setup_context == _SingleLevelFunction.setup_context:
    542     raise RuntimeError(
    543         "In order to use an autograd.Function with functorch transforms "
    544         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    545         "staticmethod. For more details, please see "
    546         "https://pytorch.org/docs/master/notes/extending.func.html"
    547     )

File /usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:620, in FlashAttnVarlenFunc.forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, deterministic, return_softmax, block_table)
    618 if softmax_scale is None:
    619     softmax_scale = q.shape[-1] ** (-0.5)
--> 620 out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
    621     q,
    622     k,
    623     v,
    624     cu_seqlens_q,
    625     cu_seqlens_k,
    626     max_seqlen_q,
    627     max_seqlen_k,
    628     dropout_p,
    629     softmax_scale,
    630     causal=causal,
    631     window_size=window_size,
    632     softcap=softcap,
    633     alibi_slopes=alibi_slopes,
    634     return_softmax=return_softmax and dropout_p > 0,
    635     block_table=block_table,
    636 )
    637 ctx.save_for_backward(
    638     q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
    639 )
    640 ctx.dropout_p = dropout_p

File /usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py:90, in _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax, block_table, leftpad_k, seqused_k)
     70 def _flash_attn_varlen_forward(
     71     q,
     72     k,
   (...)
     87     seqused_k=None,
     88 ):
     89     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
---> 90     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
     91         q,
     92         k,
     93         v,
     94         None,
     95         cu_seqlens_q,
     96         cu_seqlens_k,
     97         seqused_k,
     98         leftpad_k,
     99         block_table,
    100         alibi_slopes,
    101         max_seqlen_q,
    102         max_seqlen_k,
    103         dropout_p,
    104         softmax_scale,
    105         False,
    106         causal,
    107         window_size[0],
    108         window_size[1],
    109         softcap,
    110         return_softmax,
    111         None,
    112     )
    113     # if out.isnan().any() or softmax_lse.isnan().any():
    114     #     breakpoint()
    115     return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state

RuntimeError: batch size must be positive
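For context on the final error: the traceback shows the call reaching `flash_attn_varlen_func` through the branch of `_flash_attention_forward` that unpacks `cu_seq_lens` and `max_seq_lens`. As far as I can tell from the flash-attn sources, the varlen kernel infers its batch size as `len(cu_seqlens_q) - 1` and raises "batch size must be positive" when that is not greater than zero. The sketch below is a hypothetical, pure-Python illustration (not transformers' actual code) of one way that can happen: if sequence boundaries are derived from positions that restart at 0, position ids that never contain 0 produce a single offset and therefore zero sequences. The helper names are made up for illustration.

```python
def cu_seqlens_from_position_ids(position_ids: list[int]) -> list[int]:
    """Hypothetical helper: treat every position 0 as the start of a new packed sequence.

    Returns cumulative sequence offsets, ending with the total token count.
    """
    starts = [i for i, pos in enumerate(position_ids) if pos == 0]
    return starts + [len(position_ids)]


def varlen_batch_size(cu_seqlens: list[int]) -> int:
    # Number of sequences implied by the offsets (one fewer than the offsets).
    return len(cu_seqlens) - 1


# Two packed sequences of lengths 3 and 2.
packed = cu_seqlens_from_position_ids([0, 1, 2, 0, 1])
print(packed, varlen_batch_size(packed))  # [0, 3, 5] 2

# Positions that never restart at 0 yield no sequence starts at all,
# i.e. a batch size of 0, which a varlen kernel would reject.
shifted = cu_seqlens_from_position_ids([5, 6, 7])
print(shifted, varlen_batch_size(shifted))  # [3] 0
```

This only illustrates the failure mode behind the message; the actual root cause in transformers is whatever #32241 addresses.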

Expected behavior

I'd expect the above script to generate without error (a similar script ran successfully earlier this week but now yields the same error).

I believe one of these two issues might be related to this issue:

tctrautman added the bug label Jul 26, 2024
@zucchini-nlp
Member

Hey! Indeed, Flash Attention seems to be broken in the latest release, caused by #31629. I've located the cause and will work on a fix; in the meantime, you can downgrade transformers to v4.42.4 or earlier and try generating again :)
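Until a fixed release is installed, a script can guard the `attn_implementation` argument against the version ceiling suggested in this comment. The sketch below is a hypothetical helper: the 4.42.4 ceiling comes from the comment above, while falling back to `"eager"` (a standard `attn_implementation` value) is my own assumption, not the maintainer's advice. It parses the version string with the standard library only, so strings such as `"4.44.0.dev0"` from the System Info work.

```python
import re

# Ceiling taken from the maintainer's suggestion above; this is an assumption
# encoded for illustration, not an official compatibility table.
MAX_KNOWN_GOOD = (4, 42, 4)


def parse_release(version: str) -> tuple[int, int, int]:
    """Extract (major, minor, patch) from strings like '4.44.0.dev0'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return (major, minor, patch)


def pick_attn_implementation(version: str) -> str:
    """Hypothetical helper: request FA2 only at or below the suggested ceiling."""
    if parse_release(version) <= MAX_KNOWN_GOOD:
        return "flash_attention_2"
    # Assumed fallback: eager attention avoids the flash-attn code path entirely.
    return "eager"
```

In the reproduction script, one would pass `attn_implementation=pick_attn_implementation(transformers.__version__)` to `from_pretrained`. Because the helper only knows the ceiling, it will also disable FA2 on releases that include the fix from #32241, so the ceiling should be revisited once that fix ships.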

@tctrautman
Author

Thank you, @zucchini-nlp!
