From 251a2409c694c29ee28e66c954670c483cf54961 Mon Sep 17 00:00:00 2001 From: James Thewlis Date: Tue, 23 Jul 2024 01:12:16 -0400 Subject: [PATCH] Add llama3-llava-next-8b to llava_next conversion script (#31395) * Add llama3-llava-next-8b to llava_next conversion script Adds support for the lmms-lab/llama3-llava-next-8b model to the convert_llava_next_weights_to_hf.py script, along with an example prompt generated from the llava_llama_3 conv_template in the LLaVA-NeXT repo. * Exclude <|begin_of_text|> from prompt example This token gets added automatically, so it should not be included in the prompt example. * Add llava-next-72b and llava-next-110b Adds the Qwen-based LLaVA-Next models to the conversion script, along with changes to load the models on multiple GPUs for inference. * Add llama3 and qwen prompt formats to docs * Chat prompt and padding side left for llama3 batched * update * Update src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * remove code * better naming --------- Co-authored-by: raushan Co-authored-by: Raushan Turganbay Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- docs/source/en/model_doc/llava_next.md | 11 ++ .../convert_llava_next_weights_to_hf.py | 125 +++++++++++++----- 2 files changed, 101 insertions(+), 35 deletions(-) diff --git a/docs/source/en/model_doc/llava_next.md b/docs/source/en/model_doc/llava_next.md index b9d06ff97ffa53..9e7caa37d7b9bc 100644 --- a/docs/source/en/model_doc/llava_next.md +++ b/docs/source/en/model_doc/llava_next.md @@ -100,6 +100,17 @@ print(text_prompt) "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n" ``` +[llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llava-next-8b-hf) requires the following format: + +```bash +"<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" +``` + +[llava-next-72b-hf](https://huggingface.co/llava-hf/llava-next-72b-hf) and [llava-next-110b-hf](https://huggingface.co/llava-hf/llava-next-110b-hf) require the following format: + +```bash +"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n\nWhat is shown in this image?<|im_end|>\n<|im_start|>assistant\n" +``` ## Usage example diff --git a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py index 2c8aefe39dc255..06edc5c9b1adbc 100644 --- a/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py +++ b/src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py @@ -24,6 +24,7 @@ """ import argparse +import gc import glob import json from pathlib import Path @@ -111,6 +112,16 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): elif model_id == "liuhaotian/llava-v1.6-34b": text_model_id = "NousResearch/Nous-Hermes-2-Yi-34B" image_token_index = 64000 + elif model_id == "lmms-lab/llama3-llava-next-8b": + text_model_id = "meta-llama/Meta-Llama-3-8B-Instruct" + image_token_index = 128256 + elif model_id == "lmms-lab/llava-next-72b": + text_model_id = "Qwen/Qwen1.5-72B-Chat" + image_token_index = 151646 + elif model_id == "lmms-lab/llava-next-110b": + text_model_id = "Qwen/Qwen1.5-110B-Chat" + image_token_index = 151646 + vision_model_id = data["mm_vision_tower"] torch.set_default_dtype(torch.float16) @@ -120,7 +131,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): tokenizer = AutoTokenizer.from_pretrained(text_model_id, use_fast=use_fast) tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special_tokens=True) - if model_id == "liuhaotian/llava-v1.6-mistral-7b": + if model_id in ("liuhaotian/llava-v1.6-mistral-7b", "lmms-lab/llama3-llava-next-8b"): # Mistral-7B doesn't have a padding token set yet tokenizer.add_special_tokens({"pad_token": ""}) @@ -151,28 +162,45 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): # We add an image token so we resize the model # Pad to 64 for performance reasons - pad_shape = 64 - vocab_size = config.text_config.vocab_size - if model_id == "liuhaotian/llava-v1.6-34b": - # this one has 3 additional tokens, namely <|startoftext|>, <|endoftext|> and - num_tokens = vocab_size + 3 - else: - # this one has 2 additional tokens, namely and - num_tokens = vocab_size + 2 - model.resize_token_embeddings(num_tokens, pad_to_multiple_of=pad_shape) - model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack( - tuple( - (dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0])) - ), - dim=0, - ) - model.language_model.lm_head.weight.data[vocab_size:] = torch.stack( - tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))), - dim=0, - ) + # Qwen-based models have extra unused space in the vocab size already, so no need to resize + if model_id not in ["lmms-lab/llava-next-72b", "lmms-lab/llava-next-110b"]: + pad_shape = 64 + vocab_size = config.text_config.vocab_size + if model_id == "liuhaotian/llava-v1.6-34b": + # this one has 3 additional tokens, namely <|startoftext|>, <|endoftext|> and + num_tokens = vocab_size + 3 + else: + # this one has 2 additional tokens, namely and + num_tokens = vocab_size + 2 + model.resize_token_embeddings(num_tokens, pad_to_multiple_of=pad_shape) + model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack( + tuple( + ( + dist.sample() + for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0]) + ) + ), + dim=0, + ) + model.language_model.lm_head.weight.data[vocab_size:] = torch.stack( + tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))), + dim=0, + ) + + print(f"Saving model and processor for {model_id} to {pytorch_dump_folder_path}") + Path(pytorch_dump_folder_path).mkdir(exist_ok=True) + model.save_pretrained(pytorch_dump_folder_path) + processor.save_pretrained(pytorch_dump_folder_path) + + # Make space so we can load the model properly now. + del state_dict + gc.collect() - device = "cuda:2" - model.to(device) + # Load everything back for inference tests in float32 because prev script was written as that + # Though it's mostly loaded in fp16 as original weights are in fp16 + model = LlavaNextForConditionalGeneration.from_pretrained(pytorch_dump_folder_path, device_map="auto") + processor = LlavaNextProcessor.from_pretrained(pytorch_dump_folder_path) + device = model.device # prepare inputs image = load_image() @@ -182,6 +210,11 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nWhat is shown in this image? ASSISTANT:" elif model_id == "liuhaotian/llava-v1.6-34b": prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n" + elif model_id == "lmms-lab/llama3-llava-next-8b": + prompt = "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + elif model_id in ["lmms-lab/llava-next-72b", "lmms-lab/llava-next-110b"]: + prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n\nWhat is shown in this image?<|im_end|>\n<|im_start|>assistant\n" + inputs = processor(images=image, text=prompt, return_tensors="pt") # verify inputs @@ -194,8 +227,6 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): original_input_ids = torch.load(filepath, map_location="cpu") # replace -200 by image_token_index (since we use token ID = 32000 for the image token) original_input_ids[original_input_ids == -200] = image_token_index - print(tokenizer.decode([id for id in original_input_ids.tolist()[0] if id != -200])) - assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist() elif model_id == "liuhaotian/llava-v1.6-34b": @@ -243,6 +274,26 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): dtype=torch.float32, device=device, ) + elif model_id == "lmms-lab/llama3-llava-next-8b": + expected_slice = torch.tensor( + [[-3.9648, 1.1396, 3.3145], [-5.3594, -1.5654, -1.9619], [-12.3750, -10.6797, -9.3125]], + dtype=torch.float32, + device=device, + ) + elif model_id == "lmms-lab/llava-next-72b": + # Not yet checked against reference + expected_slice = torch.tensor( + [[3.7148, 3.9277, 3.4395], [-0.4341, 1.1387, 6.5117], [3.2324, 3.4688, 4.1133]], + dtype=torch.float32, + device=device, + ) + elif model_id == "lmms-lab/llava-next-110b": + # Not yet checked against reference + expected_slice = torch.tensor( + [[-2.5449, -1.6738, -2.0371], [1.0811, 3.4961, 5.0312], [1.7803, 2.5137, 2.4277]], + dtype=torch.float32, + device=device, + ) else: raise ValueError(f"Model {model_id} not supported") @@ -268,6 +319,12 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): expected_text = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a radar chart, also known as a spider chart or star chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular radar chart, there are several variables represented:\n\n- MM-Vet\n- LLa-Va-Bench\n- SEED-Bench\n- MM" elif model_id == "liuhaotian/llava-v1.6-34b": expected_text = "<|im_start|> system\nAnswer the questions. <|im_start|> user\n\nWhat is shown in this image? <|im_start|> assistant\nThe image appears to be a radar chart, also known as a spider chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular chart, there are several datasets represented by different colors and labeled with various acronyms such as MM-Vet, LLaVA-Bench, SEED-Bench, MM-Bench-CN, MM-" + elif model_id == "lmms-lab/llama3-llava-next-8b": + expected_text = 'system\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.user\n\n\nWhat is shown in this image?assistant\n\n\nThe image shows a radar chart, also known as a spider chart or a web chart, which is a type of graph used to display multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the values are plotted along each axis and connected to form a polygon.\n\nIn this particular radar chart, there are several axes labeled with different variables, such as "MM-Vet," "LL' + elif model_id == "lmms-lab/llava-next-72b": + expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image displays a radar chart, also known as a spider chart or a star chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the value of each variable is represented by the distance from the center of the chart to the point where the axis intersects with the line representing that variable's value.\n\nIn this particular chart, there are several axes" + elif model_id == "lmms-lab/llava-next-110b": + expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image shows a radar chart comparing the performance of different models on various visual question answering (VQA) benchmarks. Each colored line represents a different model, and the distance from the center of the chart indicates the score or performance level of the model on a particular benchmark. The benchmarks are labeled around the edges of the chart, and include VQA v2, GQA, VizWiz, TextVQA, MMBench-CN, MME, and others. The chart allows for a" else: raise ValueError(f"Model {model_id} not supported") @@ -281,7 +338,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): inputs = processor( images=[image, cats_image], - text=[prompt, "[INST] \nHow many cats are there? [/INST]"], + text=[prompt, prompt], padding=True, return_tensors="pt", ).to(device) @@ -305,16 +362,11 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True) print(outputs) - if pytorch_dump_folder_path is not None: - print(f"Saving model and processor for {model_id} to {pytorch_dump_folder_path}") - Path(pytorch_dump_folder_path).mkdir(exist_ok=True) - model.save_pretrained(pytorch_dump_folder_path) - processor.save_pretrained(pytorch_dump_folder_path) - if push_to_hub: - repo_id = model_id.split("/")[-1] - model.push_to_hub(f"llava-hf/{repo_id}-hf") - processor.push_to_hub(f"llava-hf/{repo_id}-hf") + checkpoint_name = model_id.split("/")[-1] + print(f"Pushing to repo llava-hf/{checkpoint_name}-hf") + model.push_to_hub(f"llava-hf/{checkpoint_name}-hf") + processor.push_to_hub(f"llava-hf/{checkpoint_name}-hf") if __name__ == "__main__": @@ -328,11 +380,14 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False): "liuhaotian/llava-v1.6-vicuna-7b", "liuhaotian/llava-v1.6-vicuna-13b", "liuhaotian/llava-v1.6-34b", + "lmms-lab/llama3-llava-next-8b", + "lmms-lab/llava-next-72b", + "lmms-lab/llava-next-110b", ], required=False, ) parser.add_argument( - "--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory." + "--pytorch_dump_folder_path", type=str, required=True, help="Path to the output PyTorch model directory." ) parser.add_argument( "--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."