diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index b52a93ae94841b..16c05a14028eee 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -1874,7 +1874,10 @@ def apply_chat_template(
                 **template_kwargs,
             )
             if continue_final_message:
-                final_message = chat[-1]["content"].strip()
+                final_message = chat[-1]["content"]
+                if isinstance(final_message, (list, tuple)):
+                    final_message = final_message[-1]["text"]
+                final_message = final_message.strip()
                 rendered_chat = rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip()
             rendered.append(rendered_chat)
 
diff --git a/tests/models/llava/test_processor_llava.py b/tests/models/llava/test_processor_llava.py
index 06a18061579670..d3a66a16df9a64 100644
--- a/tests/models/llava/test_processor_llava.py
+++ b/tests/models/llava/test_processor_llava.py
@@ -93,3 +93,24 @@ def test_chat_template(self):
         formatted_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
 
         self.assertEqual(expected_prompt, formatted_prompt)
+
+    def test_chat_template_with_continue_final_message(self):
+        processor = LlavaProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
+        expected_prompt = "USER: <image>\nDescribe this image. ASSISTANT: There is a dog and"
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image"},
+                    {"type": "text", "text": "Describe this image."},
+                ],
+            },
+            {
+                "role": "assistant",
+                "content": [
+                    {"type": "text", "text": "There is a dog and"},
+                ],
+            },
+        ]
+        prompt = processor.apply_chat_template(messages, continue_final_message=True)
+        self.assertEqual(expected_prompt, prompt)
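
For context, a minimal standalone sketch of the truncation logic this patch fixes; the helper name `continue_final` and the sample strings are illustrative only and not part of the patch:

# Standalone sketch; `continue_final` is a hypothetical helper, not a
# function in transformers.
def continue_final(rendered_chat: str, final_content) -> str:
    # With this patch, list/tuple content (multimodal message chunks) is
    # handled by taking the text of the last chunk; plain-string content
    # behaves exactly as before.
    if isinstance(final_content, (list, tuple)):
        final_content = final_content[-1]["text"]
    final_message = final_content.strip()
    # Cut the rendered prompt right after the final message so generation
    # continues the assistant's partial answer instead of opening a new turn.
    return rendered_chat[: rendered_chat.rindex(final_message) + len(final_message)].rstrip()


# Before the fix, chat[-1]["content"].strip() raised AttributeError when the
# final message's "content" was a list of typed chunks rather than a str:
print(continue_final(
    "USER: <image>\nDescribe this image. ASSISTANT: There is a dog and</s>",
    [{"type": "text", "text": "There is a dog and"}],
))
# -> "USER: <image>\nDescribe this image. ASSISTANT: There is a dog and"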