
Fix continue_final_message for image-text-to-text chat templates #34236

Merged 2 commits into huggingface:main on Oct 22, 2024

Conversation

yonigozlan (Member)

What does this PR do?

The `content` field for an image-text-to-text model is a list of dicts, which `tokenization_utils_base` does not currently take into account when `continue_final_message` is set to `True`. Split from the image-text-to-text PR.

To reproduce the error:

from transformers import LlavaProcessor, LlavaForConditionalGeneration
import torch
from PIL import Image
import requests

processor = LlavaProcessor.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf")

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-interleave-qwen-0.5b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True
)
model.to("cuda:0")

# Define a chat history and use `apply_chat_template` to get a correctly formatted prompt.
# Each value in "content" has to be a list of dicts with types ("text", "image").
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "There is a dog and"},
        ],
    },
]
prompt = processor.apply_chat_template(messages, continue_final_message=True)

# Fetch the image referenced in the chat and pass it alongside the prompt.
image = Image.open(requests.get(messages[0]["content"][0]["image"], stream=True).raw)
inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0").to(torch.float16)

# Autoregressively complete the prompt.
output = model.generate(**inputs, max_new_tokens=100)

print(processor.decode(output[0], skip_special_tokens=True))
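
For context, a minimal sketch of the kind of handling the fix needs, written as a simplified standalone helper (the name trim_to_final_message is illustrative; the real logic lives inline in tokenization_utils_base's apply_chat_template):

def trim_to_final_message(rendered_chat: str, final_message_content) -> str:
    """Trim the rendered prompt so generation continues the final message.

    `final_message_content` is either a plain string (text-only chats) or,
    for image-text-to-text models, a list of dicts such as
    [{"type": "text", "text": "There is a dog and"}].
    """
    if isinstance(final_message_content, (list, tuple)):
        # Take the text of the last text chunk in the content list.
        final_text = final_message_content[-1]["text"]
    else:
        final_text = final_message_content
    final_text = final_text.strip()
    # Keep everything up to (and including) the last occurrence of the
    # final text, dropping any end-of-turn tokens the template appended.
    return rendered_chat[: rendered_chat.rindex(final_text) + len(final_text)]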

cc @zucchini-nlp @ArthurZucker

@zucchini-nlp (Member) left a comment:

LGTM, thanks! Maybe also cc @Rocketknight1

@Rocketknight1 (Member) left a comment:

Yes, LGTM too!

@ArthurZucker (Collaborator) left a comment:

Cool! Can we have a small test please? 🤗

@yonigozlan (Member, Author):

Added one test for the Llava processor :). I could add one for every VLM processor that uses a chat template, but since they all use the same underlying apply_chat_template, I thought it wasn't worth the diffs. Wdyt?
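
For reference, a hypothetical shape for such a test (illustrative only; the actual test added by this PR lives in the Llava processor test suite, and the function name here is an assumption):

from transformers import LlavaProcessor

def test_chat_template_continue_final_message():
    processor = LlavaProcessor.from_pretrained("llava-hf/llava-interleave-qwen-0.5b-hf")
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Describe this image."},
            ],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "There is a dog and"}],
        },
    ]
    prompt = processor.apply_chat_template(messages, continue_final_message=True)
    # With continue_final_message=True, the rendered prompt must end with
    # the partial assistant text, with no end-of-turn token after it.
    assert prompt.endswith("There is a dog and")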

@yonigozlan force-pushed the fix-tokenization-base branch from 64523c2 to afc298f on October 21, 2024 at 21:18
@ArthurZucker (Collaborator) left a comment:

Thanks 😉

@yonigozlan merged commit e7c3fa7 into huggingface:main on Oct 22, 2024. 23 of 25 checks passed.
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request on Dec 5, 2024:
Fix continue_final_message for image-text-to-text chat templates (huggingface#34236)
* fix continue_final_message for vlms
* Add one test for vlms continue_final_message chat template
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request on Dec 6, 2024:
Fix continue_final_message for image-text-to-text chat templates (huggingface#34236)
* fix continue_final_message for vlms
* Add one test for vlms continue_final_message chat template