
LLaVA: index error when computing extended_attention_mask #28197

Closed · TideDra opened this issue Dec 22, 2023 · 2 comments · Fixed by #28032


TideDra commented Dec 22, 2023

System Info

  • transformers version: 4.36.2
  • Platform: Linux-5.15.0-1042-azure-x86_64-with-glibc2.35
  • Python version: 3.10.13
  • Huggingface_hub version: 0.19.4
  • Safetensors version: 0.4.1
  • Accelerate version: 0.21.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.0+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed

Who can help?

@younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I'm evaluating llava-1.5-7b-hf on MM-Vet with batched generation and use_cache=True; here is my script:

import json
import os

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoProcessor, AutoTokenizer, LlavaForConditionalGeneration

DATA_ROOT = "/mnt/gozhang/code/LLaVA/playground/data/eval/mm-vet"

processor = AutoProcessor.from_pretrained("/mnt/gozhang/ckpts/llava-1.5-7b-hf")
tokenizer = AutoTokenizer.from_pretrained("/mnt/gozhang/ckpts/llava-1.5-7b-hf")
processor.tokenizer.pad_token = processor.tokenizer.bos_token


class MMVetDataset(Dataset):
    def __init__(self, data_root) -> None:
        super().__init__()
        self.data_root = data_root
        with open(os.path.join(data_root, "mm-vet.json"), "r") as f:
            data = json.load(f)
        self.data = [(k, v) for k, v in data.items()]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return {
            'id': self.data[index][0],
            'image': os.path.join(self.data_root, 'images', self.data[index][1]['imagename']),
            'question': "USER: <image>\n" + self.data[index][1]['question'] + " ASSISTANT:",
        }


def collator(batch):
    ids = [b['id'] for b in batch]
    questions = [b['question'] for b in batch]
    images = [Image.open(b['image']) for b in batch]
    inputs = processor(text=questions, images=images, return_tensors="pt", padding=True)
    return ids, inputs


model = LlavaForConditionalGeneration.from_pretrained(
    "/mnt/gozhang/ckpts/llava-1.5-7b-hf", torch_dtype=torch.float16
)
model.to('cuda')
# model.to(torch.float16)

dataset = MMVetDataset(DATA_ROOT)
dataloader = DataLoader(dataset, batch_size=16, collate_fn=collator)
results = {}
bar = tqdm(total=len(dataset))
model.eval()

with torch.inference_mode():
    for ids, inputs in dataloader:
        inputs.to('cuda')
        inputs['pixel_values'] = inputs['pixel_values'].half()
        outputs = model.generate(**inputs, temperature=0.2, do_sample=True, max_new_tokens=1024, use_cache=True)
        input_token_len = inputs['input_ids'].shape[1]
        responses = tokenizer.batch_decode(
            outputs[:, input_token_len:], skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
        for id, res in zip(ids, responses):
            results[id] = res
        bar.update(len(responses))

with open('mmvet_result.json', 'w') as f:
    json.dump(results, f, indent=4)

However, it occasionally raises RuntimeError: CUDA error: device-side assert triggered when computing extended_attention_mask. The error occurs at random points during evaluation: sometimes in the third batch, sometimes in the last batch, and so on.
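Because the device-side assert is reported asynchronously, one way to get a stack trace that points at the offending op (not part of the original script) is to force synchronous CUDA launches; a minimal sketch, assuming the environment variable is set before torch is imported:

# Hypothetical debugging aid, not part of the repro script: force synchronous
# CUDA kernel launches so the assert surfaces at the faulty indexing call.
# CUDA_LAUNCH_BLOCKING must be set before CUDA is initialized.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # imported only after the environment variable is set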

I printed some shapes inside model.forward(), and I believe extended_attention_mask is computed incorrectly.

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if inputs_embeds is None:
            # 1. Extra the input embeddings
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # 2. Merge text and images
            if pixel_values is not None and input_ids.shape[1] != 1:
                image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
                # this is not memory efficient at all (output_hidden_states=True) will save all the hidden stated.
                selected_image_feature = image_outputs.hidden_states[vision_feature_layer]

                if vision_feature_select_strategy == "default":
                    selected_image_feature = selected_image_feature[:, 1:]
                elif vision_feature_select_strategy == "full":
                    selected_image_feature = selected_image_feature
                else:
                    raise ValueError(
                        f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
                    )

                image_features = self.multi_modal_projector(selected_image_feature)
                inputs_embeds, attention_mask, position_ids = self._merge_input_ids_with_image_features(
                    image_features, inputs_embeds, input_ids, attention_mask, position_ids
                )
                if labels is None:
                    labels = torch.full_like(attention_mask, self.config.ignore_index).to(torch.long)
            else:
                # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
                # generation with cache
                if past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
                    # Retrieve the first layer to inspect the logits and mask out the hidden states
                    # that are set to 0
                    first_layer_past_key_value = past_key_values[0][0][:, 0, :, 0]
                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value == 0)
                    # Get the target length
                    target_seqlen = first_layer_past_key_value.shape[-1] + 1

                    extended_attention_mask = torch.ones(
                        (attention_mask.shape[0], target_seqlen - attention_mask.shape[1]),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )

                    # Zero-out the places where we don't need to attend
                    print(extended_attention_mask.shape)    # torch.Size([16,575])
                    print(len(past_key_values))    # 32
                    print(len(past_key_values[0]))    # 2
                    print(past_key_values[0][0].shape)    # torch.Size([16,32,688,128])
                    print(attention_mask.shape)    # torch.Size(16,114)
                    print(batch_index)    #tensor([2],device='cuda:0')
                    print(non_attended_tokens)  #tensor([687],device='cuda:0')
                    try:
                        extended_attention_mask[batch_index, non_attended_tokens] = 0
                    except:
                        pdb.set_trace()

                    attention_mask = torch.cat((attention_mask, extended_attention_mask), dim=1)
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
####Following code is ignored

Apparently, extended_attention_mask has a fixed sequence length of 575 (target_seqlen - attention_mask.shape[1]), which I think roughly corresponds to the number of image tokens, while the indices in non_attended_tokens can exceed that length and then trigger the CUDA error. Maybe extended_attention_mask should simply have length target_seqlen, without being concatenated with attention_mask? Honestly, I don't fully understand the code here; it seems off.
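To make the mismatch concrete, here is a minimal CPU-only sketch using the shapes printed above (batch 16, past key/value length 688, prompt length 114); on CPU the same indexing raises an IndexError instead of a device-side assert:

import torch

# Shapes taken from the prints above: batch=16, past KV length=688, prompt length=114.
batch_size, past_len, prompt_len = 16, 688, 114
target_seqlen = past_len + 1  # 689

# extended_attention_mask only covers the positions beyond the original prompt ...
extended_attention_mask = torch.ones(batch_size, target_seqlen - prompt_len)  # shape (16, 575)

# ... but non_attended_tokens indexes into the full past sequence (0..687),
# so an index like 687 is out of range for a mask that is only 575 wide.
batch_index = torch.tensor([2])
non_attended_tokens = torch.tensor([687])

try:
    extended_attention_mask[batch_index, non_attended_tokens] = 0
except IndexError as e:
    print(e)  # index 687 is out of bounds for dimension 1 with size 575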

Expected behavior

Batched generation should work reliably when use_cache=True.

@amyeroberts (Collaborator)

Hi @TideDra, thanks for reporting this!

There's an ongoing PR which aims to address this issue: #28032

cc @younesbelkada for reference

@younesbelkada (Contributor)

Thanks! Yes, I second what @amyeroberts said; I will treat that PR as high priority and merge it ASAP.
