From 60bb571e993b7d73257fb64044726b569fef9403 Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Tue, 21 May 2024 19:38:02 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A8=20[Idefics2]=20Update=20ignore=20i?= =?UTF-8?q?ndex=20(#30898)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update ignore index * Update docs * Update docs --- docs/source/en/model_doc/idefics2.md | 52 +++++++++++++++++++ .../models/idefics2/modeling_idefics2.py | 2 +- 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/docs/source/en/model_doc/idefics2.md b/docs/source/en/model_doc/idefics2.md index 31a7a1cdeb6e7d..5ad56b7b5c525d 100644 --- a/docs/source/en/model_doc/idefics2.md +++ b/docs/source/en/model_doc/idefics2.md @@ -87,6 +87,58 @@ generated_text = processor.batch_decode(generated_text, skip_special_tokens=True print("Generated text:", generated_text) ``` +- During training, it's important to determine which tokens the model should not learn. For Idefics2, this typically comes down to the image and padding tokens. This means that one can create the labels as follows: + +```python +import requests +from PIL import Image +from transformers import Idefics2Processor, Idefics2ForConditionalGeneration +import torch + +url_1 = "http://images.cocodataset.org/val2017/000000039769.jpg" +url_2 = "http://images.cocodataset.org/val2017/000000219578.jpg" + +image_1 = Image.open(requests.get(url_1, stream=True).raw) +image_2 = Image.open(requests.get(url_2, stream=True).raw) +images = [image_1, image_2] + +messages = [{ + "role": "user", + "content": [ + {"type": "text", "text": "What’s the difference between these two images?"}, + {"type": "image"}, + {"type": "image"}, + ], +}, +{ + "role": "assistant", + "content": [ + {"type": "text", "text": "The difference is that one image is about dogs and the other one about cats."}, + ], +}] + +device = "cuda" if torch.cuda.is_available() else "cpu" + +processor = Idefics2Processor.from_pretrained("HuggingFaceM4/idefics2-8b") +model = Idefics2ForConditionalGeneration.from_pretrained("HuggingFaceM4/idefics2-8b") +model.to(device) + +text = processor.apply_chat_template(messages, add_generation_prompt=False) +inputs = processor(images=images, text=text, return_tensors="pt").to(device) + +labels = inputs.input_ids.clone() +labels[labels == processor.tokenizer.pad_token_id] = -100 +labels[labels == model.config.image_token_id] = -100 + +inputs["labels"] = labels + +outputs = model(**inputs) +loss = outputs.loss +loss.backward() +``` + +Do note that when training Idefics2 on multi-turn conversations between a user and an assistant, one typically also sets all the tokens corresponding to the user messages to -100. + ## Model optimizations: Flash Attention The code snippets above showcase inference without any optimization tricks. However, one can drastically speed up the model by leveraging [Flash Attention](../perf_train_gpu_one.md#flash-attention-2), which is a faster implementation of the attention mechanism used inside the model. diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 17ed18f6b99dca..6acabad0635b3f 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1857,7 +1857,7 @@ def forward( shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens - loss_fct = CrossEntropyLoss(ignore_index=self.image_token_id) + loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) if not return_dict: