From 075b762030f765988e8af9772ace9eb909c18e44 Mon Sep 17 00:00:00 2001
From: Raushan Turganbay
Date: Fri, 19 Jul 2024 10:08:56 +0500
Subject: [PATCH] Llava: add default chat templates (#31691)

* add default chat templates
* Update src/transformers/models/llava/processing_llava.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/models/llava_next/processing_llava_next.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* more clear docstring and docs
* Update docs/source/en/model_doc/llava.md
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* Update docs/source/en/model_doc/llava_next.md
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* Update docs/source/en/model_doc/vipllava.md
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
* add tests
* remove default templates (see #31733)
* load chat template from another file
* Update docs/source/en/model_doc/llava_next.md
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* revert some changes in docs
* forgot vipllava
* chat template file is not temporary hack
* warn if loading from processor
* not that file
* similarly modify `save_pretrained`
* Update tests/models/llava_next/test_processor_llava_next.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update tests/models/vipllava/test_processor_vipllava.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update docs/source/en/model_doc/vipllava.md
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/processing_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/processing_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update docs/source/en/model_doc/vipllava.md
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update docs/source/en/model_doc/llava.md
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update docs/source/en/model_doc/llava.md
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update docs/source/en/model_doc/llava_next.md
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update docs/source/en/model_doc/llava_next.md
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update src/transformers/processing_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* Update docs/source/en/model_doc/llava_next.md
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
* fix

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
---
 docs/source/en/model_doc/llava.md             | 38 +++++++-
 docs/source/en/model_doc/llava_next.md        | 96 +++++++++++++++++--
 docs/source/en/model_doc/vipllava.md          | 48 ++++++++--
 src/transformers/processing_utils.py          | 55 ++++++++++-
 src/transformers/utils/__init__.py            |  1 +
 tests/models/llava/test_processor_llava.py    | 17 ++++
 .../llava_next/test_processor_llava_next.py   | 41 ++++++++
 .../vipllava/test_processor_vipllava.py       | 41 ++++++++
 8 files changed, 318 insertions(+), 19 deletions(-)
 create mode 100644 tests/models/llava_next/test_processor_llava_next.py
 create mode 100644 tests/models/vipllava/test_processor_vipllava.py
diff --git a/docs/source/en/model_doc/llava.md b/docs/source/en/model_doc/llava.md
index 43eaa41d5d7140..a7e4b4da7f3c5a 100644
--- a/docs/source/en/model_doc/llava.md
+++ b/docs/source/en/model_doc/llava.md
@@ -40,7 +40,42 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/
 
 - Note the model has not been explicitly trained to process multiple images in the same prompt, although this is technically possible, you may experience inaccurate results.
 
-- For better results, we recommend users to prompt the model with the correct prompt format. Below is a list of prompt formats accepted by each llava checkpoint:
+- For better results, we recommend using the processor's `apply_chat_template()` method to format your prompt correctly. To do so, construct a conversation history; passing in a plain string will not format your prompt. Each message in the conversation history is a dictionary with the keys "role" and "content", and "content" should be a list of dictionaries for the "text" and "image" modalities, as follows:
+
+```python
+from transformers import AutoProcessor
+
+processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
+
+conversation = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "text", "text": "What’s shown in this image?"},
+        ],
+    },
+    {
+        "role": "assistant",
+        "content": [{"type": "text", "text": "This image shows a red stop sign."}],
+    },
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Describe the image in more details."},
+        ],
+    },
+]
+
+text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+# Note that the template simply formats your prompt, you still have to tokenize it and obtain pixel values for your images
+print(text_prompt)
+>>> "USER: <image>\nWhat’s shown in this image? ASSISTANT: This image shows a red stop sign.USER: Describe the image in more details. ASSISTANT:"
+```
+
+- If you want to construct a chat prompt yourself, below is a list of prompt formats accepted by each llava checkpoint:
 
 [llava-interleave models](https://huggingface.co/collections/llava-hf/llava-interleave-668e19a97da0036aad4a2f19) requires the following format:
 ```bash
@@ -64,6 +99,7 @@ For multiple turns conversation:
 "USER: <image>\n<prompt1> ASSISTANT: <answer1>USER: <prompt2> ASSISTANT: <answer2>USER: <prompt3> ASSISTANT:"
 ```
+
 ### Using Flash Attention 2
 
 Flash Attention 2 is an even faster, optimized version of the previous optimization, please refer to the [Flash Attention 2 section of performance docs](https://huggingface.co/docs/transformers/perf_infer_gpu_one).
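The formatted string returned by `apply_chat_template` still has to go through the processor together with an image before generation. The sketch below is not part of the patch; it is a minimal example of completing that round trip with the `llava-hf/llava-1.5-7b-hf` checkpoint used above, mirroring the usage pattern already shown in these docs. The COCO image URL and the generation settings are illustrative assumptions.

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"  # same checkpoint as in the documentation example above
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda:0")

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What’s shown in this image?"},
        ],
    },
]
# The template only builds the text prompt; tokenization and pixel values come from the processor call below
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # illustrative image
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0", torch.float16)

output = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(output[0], skip_special_tokens=True))
```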
diff --git a/docs/source/en/model_doc/llava_next.md b/docs/source/en/model_doc/llava_next.md
index a4a1419ee00ac8..b9d06ff97ffa53 100644
--- a/docs/source/en/model_doc/llava_next.md
+++ b/docs/source/en/model_doc/llava_next.md
@@ -46,26 +46,61 @@ The original code can be found [here](https://github.com/haotian-liu/LLaVA/tree/
 
 - We advise users to use `padding_side="left"` when computing batched generation as it leads to more accurate results. Simply make sure to call `processor.tokenizer.padding_side = "left"` before generating.
 
-- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. Below, we list the correct prompt formats to use for the text prompt "What is shown in this image?":
+- Note that each checkpoint has been trained with a specific prompt format, depending on which large language model (LLM) was used. You can use the processor's `apply_chat_template` method to format your prompts correctly. To do so, you have to construct a conversation history; passing a plain string will not format your prompt. Each message in the conversation history is a dictionary with the keys "role" and "content", and "content" should be a list of dictionaries for the "text" and "image" modalities. Below is an example of how to do that, followed by the list of formats accepted by each checkpoint.
 
-[llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) requires the following format:
+We will use [llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) and a conversation history containing text and an image. Each content field has to be a list of dicts, as follows:
+
+```python
+from transformers import LlavaNextProcessor
+
+processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+
+conversation = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "text", "text": "What’s shown in this image?"},
+        ],
+    },
+    {
+        "role": "assistant",
+        "content": [{"type": "text", "text": "This image shows a red stop sign."}],
+    },
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Describe the image in more details."},
+        ],
+    },
+]
+
+text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+# Note that the template simply formats your prompt, you still have to tokenize it and obtain pixel values for your images
+print(text_prompt)
+>>> "[INST] <image>\nWhat's shown in this image? [/INST] This image shows a red stop sign. [INST] Describe the image in more details. [/INST]"
+```
+
+- If you want to construct a chat prompt yourself, below is a list of possible formats:
+
+[llava-v1.6-mistral-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf) requires the following format:
 
 ```bash
 "[INST] <image>\nWhat is shown in this image? [/INST]"
 ```
 
 [llava-v1.6-vicuna-7b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-7b-hf) and [llava-v1.6-vicuna-13b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) require the following format:
-
 ```bash
 "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
 ```
 
 [llava-v1.6-34b-hf](https://huggingface.co/llava-hf/llava-v1.6-34b-hf) requires the following format:
-
 ```bash
 "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
 ```
+
 ## Usage example
 
 ### Single image inference
@@ -86,8 +121,17 @@ model.to("cuda:0")
 
 # prepare image and text prompt, using the appropriate prompt template
 url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
 image = Image.open(requests.get(url, stream=True).raw)
-prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
[/INST]" +conversation = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, +] +prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) inputs = processor(prompt, image, return_tensors="pt").to("cuda:0") # autoregressively complete prompt @@ -120,15 +164,47 @@ image_cats = Image.open(requests.get(url, stream=True).raw) url = "https://huggingface.co/microsoft/kosmos-2-patch14-224/resolve/main/snowman.jpg" image_snowman = Image.open(requests.get(url, stream=True).raw) -# Prepare a batched prompt, where the first one is a multi-turn conversation and the second is not -prompt = [ - "[INST] \nWhat is shown in this image? [/INST] There is a red stop sign in the image. [INST] \nWhat about this image? How many cats do you see [/INST]", - "[INST] \nWhat is shown in this image? [/INST]" +# Prepare a batch of two prompts, where the first one is a multi-turn conversation and the second is not +conversation_1 = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "There is a red stop sign in the image."}, + ], + }, + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What about this image? How many cats do you see?"}, + ], + }, ] +conversation_2 = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, +] + +prompt_1 = processor.apply_chat_template(conversation_1, add_generation_prompt=True) +prompt_2 = processor.apply_chat_template(conversation_2, add_generation_prompt=True) +prompts = [prompt_1, prompt_2] + # We can simply feed images in the order they have to be used in the text prompt # Each "" token uses one image leaving the next for the subsequent "" tokens -inputs = processor(text=prompt, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device) +inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(model.device) # Generate generate_ids = model.generate(**inputs, max_new_tokens=30) diff --git a/docs/source/en/model_doc/vipllava.md b/docs/source/en/model_doc/vipllava.md index 35f2467486a895..b3e76cd292e40a 100644 --- a/docs/source/en/model_doc/vipllava.md +++ b/docs/source/en/model_doc/vipllava.md @@ -26,7 +26,12 @@ The abstract from the paper is the following: *While existing large vision-language multimodal models focus on whole image understanding, there is a prominent gap in achieving region-specific comprehension. Current approaches that use textual coordinates or spatial encodings often fail to provide a user-friendly interface for visual prompting. To address this challenge, we introduce a novel multimodal model capable of decoding arbitrary visual prompts. This allows users to intuitively mark images and interact with the model using natural cues like a "red bounding box" or "pointed arrow". Our simple design directly overlays visual markers onto the RGB image, eliminating the need for complex region encodings, yet achieves state-of-the-art performance on region-understanding tasks like Visual7W, PointQA, and Visual Commonsense Reasoning benchmark. 
diff --git a/docs/source/en/model_doc/vipllava.md b/docs/source/en/model_doc/vipllava.md
index 35f2467486a895..b3e76cd292e40a 100644
--- a/docs/source/en/model_doc/vipllava.md
+++ b/docs/source/en/model_doc/vipllava.md
@@ -26,7 +26,12 @@ The abstract from the paper is the following:
 
 *While existing large vision-language multimodal models focus on whole image understanding, there is a prominent gap in achieving region-specific comprehension. Current approaches that use textual coordinates or spatial encodings often fail to provide a user-friendly interface for visual prompting. To address this challenge, we introduce a novel multimodal model capable of decoding arbitrary visual prompts. This allows users to intuitively mark images and interact with the model using natural cues like a "red bounding box" or "pointed arrow". Our simple design directly overlays visual markers onto the RGB image, eliminating the need for complex region encodings, yet achieves state-of-the-art performance on region-understanding tasks like Visual7W, PointQA, and Visual Commonsense Reasoning benchmark. Furthermore, we present ViP-Bench, a comprehensive benchmark to assess the capability of models in understanding visual prompts across multiple dimensions, enabling future research in this domain. Code, data, and model are publicly available.*
 
-Tips:
+The original code can be found [here](https://github.com/mu-cai/ViP-LLaVA).
+
+This model was contributed by [Younes Belkada](https://huggingface.co/ybelkada)
+
+
+## Usage tips
 
 - The architecture is similar than llava architecture except that the multi-modal projector takes a set of concatenated vision hidden states and has an additional layernorm layer on that module.
 
@@ -34,22 +39,51 @@ Tips:
 
 - Note the model has not been explicitly trained to process multiple images in the same prompt, although this is technically possible, you may experience inaccurate results.
 
-- For better results, we recommend users to prompt the model with the correct prompt format:
+- For better results, we recommend using the processor's `apply_chat_template()` method to format your prompt correctly. To do so, construct a conversation history; passing in a plain string will not format your prompt. Each message in the conversation history is a dictionary with the keys "role" and "content", and "content" should be a list of dictionaries for the "text" and "image" modalities, as follows:
+
+```python
+from transformers import AutoProcessor
+
+processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf")
+
+conversation = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "text", "text": "What’s shown in this image?"},
+        ],
+    },
+    {
+        "role": "assistant",
+        "content": [{"type": "text", "text": "This image shows a red stop sign."}],
+    },
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": "Describe the image in more details."},
+        ],
+    },
+]
+
+text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+
+# Note that the template simply formats your prompt, you still have to tokenize it and obtain pixel values for your images
+print(text_prompt)
+>>> "###Human: <image>\nWhat’s shown in this image?###Assistant: This image shows a red stop sign.###Human: Describe the image in more details.###Assistant:"
+```
+- If you want to construct a chat prompt yourself, below is a list of prompt formats accepted by VipLLaVa checkpoints:
 ```bash
 A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: <image>\n<prompt>###Assistant:
 ```
 
 For multiple turns conversation:
-
 ```bash
 A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.###Human: <image>\n<prompt1>###Assistant: <answer1>###Human: <prompt2>###Assistant:
 ```
 
-The original code can be found [here](https://github.com/mu-cai/ViP-LLaVA).
-
-This model was contributed by [Younes Belkada](https://huggingface.co/ybelkada)
-
 
 ## VipLlavaConfig
 
diff --git a/src/transformers/processing_utils.py b/src/transformers/processing_utils.py
index 7062a7699a79f7..24c6af79663652 100644
--- a/src/transformers/processing_utils.py
+++ b/src/transformers/processing_utils.py
@@ -39,6 +39,7 @@
     TruncationStrategy,
 )
 from .utils import (
+    CHAT_TEMPLATE_NAME,
     PROCESSOR_NAME,
     PushToHubMixin,
     TensorType,
@@ -494,11 +495,21 @@ def save_pretrained(self, save_directory, push_to_hub: bool = False, **kwargs):
                 del attribute.init_kwargs["auto_map"]
 
         # If we save using the predefined names, we can load using `from_pretrained`
+        # plus we save chat_template in its own file
         output_processor_file = os.path.join(save_directory, PROCESSOR_NAME)
+        output_chat_template_file = os.path.join(save_directory, CHAT_TEMPLATE_NAME)
+
+        processor_dict = self.to_dict()
+        chat_template = processor_dict.pop("chat_template", None)
+        if chat_template is not None:
+            chat_template_json_string = json.dumps({"chat_template": chat_template}, indent=2, sort_keys=True) + "\n"
+            with open(output_chat_template_file, "w", encoding="utf-8") as writer:
+                writer.write(chat_template_json_string)
+            logger.info(f"chat template saved in {output_chat_template_file}")
 
         # For now, let's not save to `processor_config.json` if the processor doesn't have extra attributes and
         # `auto_map` is not specified.
-        if set(self.to_dict().keys()) != {"processor_class"}:
+        if set(processor_dict.keys()) != {"processor_class"}:
             self.to_json_file(output_processor_file)
             logger.info(f"processor saved in {output_processor_file}")
 
@@ -557,14 +568,21 @@ def get_processor_dict(
         is_local = os.path.isdir(pretrained_model_name_or_path)
         if os.path.isdir(pretrained_model_name_or_path):
             processor_file = os.path.join(pretrained_model_name_or_path, PROCESSOR_NAME)
+            chat_template_file = os.path.join(pretrained_model_name_or_path, "chat_template.json")
+
         if os.path.isfile(pretrained_model_name_or_path):
             resolved_processor_file = pretrained_model_name_or_path
+            # can't load a chat template when given a file as pretrained_model_name_or_path
+            resolved_chat_template_file = None
             is_local = True
         elif is_remote_url(pretrained_model_name_or_path):
             processor_file = pretrained_model_name_or_path
             resolved_processor_file = download_url(pretrained_model_name_or_path)
+            # can't load a chat template when given a file URL as pretrained_model_name_or_path
+            resolved_chat_template_file = None
         else:
             processor_file = PROCESSOR_NAME
+            chat_template_file = CHAT_TEMPLATE_NAME
             try:
                 # Load from local folder or from cache or download from model Hub and cache
                 resolved_processor_file = cached_file(
@@ -581,6 +599,24 @@
                     subfolder=subfolder,
                     _raise_exceptions_for_missing_entries=False,
                 )
+
+                # Load the chat template from a separate JSON if it exists,
+                # because making it part of the processor config would break BC.
+                # Processors in older versions do not accept any kwargs
+                resolved_chat_template_file = cached_file(
+                    pretrained_model_name_or_path,
+                    chat_template_file,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _raise_exceptions_for_missing_entries=False,
+                )
             except EnvironmentError:
                 # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to
                 # the original exception.
@@ -594,6 +630,14 @@ def get_processor_dict(
                 f" directory containing a {PROCESSOR_NAME} file"
             )
 
+        # Add chat template as kwarg before returning because most models don't have a processor config
+        chat_template = None
+        if resolved_chat_template_file is not None:
+            with open(resolved_chat_template_file, "r", encoding="utf-8") as reader:
+                text = reader.read()
+            chat_template = json.loads(text)["chat_template"]
+            kwargs["chat_template"] = chat_template
+
         # Existing processors on the Hub created before #27761 being merged don't have `processor_config.json` (if not
         # updated afterward), and we need to keep `from_pretrained` work. So here it fallbacks to the empty dict.
         # (`cached_file` called using `_raise_exceptions_for_missing_entries=False` to avoid exception)
@@ -617,6 +661,12 @@
         else:
             logger.info(f"loading configuration file {processor_file} from cache at {resolved_processor_file}")
 
+        if "chat_template" in processor_dict and processor_dict["chat_template"] is not None:
+            logger.warning_once(
+                "Chat templates should be in a 'chat_template.json' file but found key='chat_template' "
+                "in the processor's config. Make sure to move your template to its own file."
+            )
+
         if not is_local:
             if "auto_map" in processor_dict:
                 processor_dict["auto_map"] = add_model_info_to_auto_map(
@@ -648,6 +698,7 @@ def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs):
         """
         processor_dict = processor_dict.copy()
         return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+        chat_template = kwargs.pop("chat_template", None)
 
         # We have to pop up some unused (but specific) kwargs and then validate that it doesn't contain unused kwargs
         # If we don't pop, some specific kwargs will raise a warning
@@ -659,6 +710,8 @@ def from_args_and_dict(cls, args, processor_dict: Dict[str, Any], **kwargs):
         unused_kwargs = cls.validate_init_kwargs(processor_config=processor_dict, valid_kwargs=cls.valid_kwargs)
 
         processor = cls(*args, **processor_dict)
+        if chat_template is not None:
+            setattr(processor, "chat_template", chat_template)
 
         # Update processor with kwargs if needed
         for key in set(kwargs.keys()):
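To make the effect of the `save_pretrained` change above concrete, here is a small sketch (not part of the patch) of what ends up in `chat_template.json`. The checkpoint name and the toy template are only placeholders; any processor whose `chat_template` attribute is set should be serialized the same way by this code path.

```python
import tempfile

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
# Toy template, set only so the new code path has something to write
processor.chat_template = "{% for message in messages %}{{ message['role'] }}: {% endfor %}"

with tempfile.TemporaryDirectory() as tmp_dir:
    processor.save_pretrained(tmp_dir)

    # The template is written to its own file next to the processor/preprocessor configs
    with open(f"{tmp_dir}/chat_template.json", encoding="utf-8") as f:
        print(f.read())
    # {
    #   "chat_template": "{% for message in messages %}{{ message['role'] }}: {% endfor %}"
    # }
```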
ASSISTANT:" + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + + formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) + self.assertEquals(expected_prompt, formatted_prompt) diff --git a/tests/models/llava_next/test_processor_llava_next.py b/tests/models/llava_next/test_processor_llava_next.py new file mode 100644 index 00000000000000..0a6eccf555e7b0 --- /dev/null +++ b/tests/models/llava_next/test_processor_llava_next.py @@ -0,0 +1,41 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from transformers.testing_utils import require_vision +from transformers.utils import is_vision_available + + +if is_vision_available(): + from transformers import AutoProcessor + + +@require_vision +class LlavaProcessorTest(unittest.TestCase): + def test_chat_template(self): + processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf") + expected_prompt = "USER: \nWhat is shown in this image? ASSISTANT:" + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + + formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) + self.assertEquals(expected_prompt, formatted_prompt) diff --git a/tests/models/vipllava/test_processor_vipllava.py b/tests/models/vipllava/test_processor_vipllava.py new file mode 100644 index 00000000000000..896eddd09c7130 --- /dev/null +++ b/tests/models/vipllava/test_processor_vipllava.py @@ -0,0 +1,41 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from transformers.testing_utils import require_vision +from transformers.utils import is_vision_available + + +if is_vision_available(): + from transformers import AutoProcessor + + +@require_vision +class LlavaProcessorTest(unittest.TestCase): + def test_chat_template(self): + processor = AutoProcessor.from_pretrained("llava-hf/vip-llava-7b-hf") + expected_prompt = "###Human: \nWhat is shown in this image?###Assistant:" + + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + + formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True) + self.assertEquals(expected_prompt, formatted_prompt)