Skip to content

Commit

Permalink
Add llama3-llava-next-8b to llava_next conversion script (huggingface…
Browse files Browse the repository at this point in the history
…#31395)

* Add llama3-llava-next-8b to llava_next conversion script

Adds support for the lmms-lab/llama3-llava-next-8b model to the
convert_llava_next_weights_to_hf.py script, along with an example
prompt generated from the llava_llama_3 conv_template in the LLaVA-NeXT
repo.

* Exclude <|begin_of_text|> from prompt example

This token gets added automatically, so it should not be included in the
prompt example.

* Add llava-next-72b and llava-next-110b

Adds the Qwen-based LLaVA-Next models to the conversion script, along
with changes to load the models on multiple GPUs for inference.

* Add llama3 and qwen prompt formats to docs

* Chat prompt and padding side left for llama3 batched

* update

* Update src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py

Co-authored-by: amyeroberts <[email protected]>

* remove code

* better naming

---------

Co-authored-by: raushan <[email protected]>
Co-authored-by: Raushan Turganbay <[email protected]>
Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
4 people committed Jul 24, 2024
1 parent 1d57a0e commit 5ba99cb
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 35 deletions.
11 changes: 11 additions & 0 deletions docs/source/en/model_doc/llava_next.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ print(text_prompt)
"<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
```

[llama3-llava-next-8b-hf](https://huggingface.co/llava-hf/llava-next-8b-hf) requires the following format:

```bash
"<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n<image>\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
```

[llava-next-72b-hf](https://huggingface.co/llava-hf/llava-next-72b-hf) and [llava-next-110b-hf](https://huggingface.co/llava-hf/llava-next-110b-hf) require the following format:

```bash
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|>\n<|im_start|>assistant\n"
```

## Usage example

Expand Down
125 changes: 90 additions & 35 deletions src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"""

import argparse
import gc
import glob
import json
from pathlib import Path
Expand Down Expand Up @@ -111,6 +112,16 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
elif model_id == "liuhaotian/llava-v1.6-34b":
text_model_id = "NousResearch/Nous-Hermes-2-Yi-34B"
image_token_index = 64000
elif model_id == "lmms-lab/llama3-llava-next-8b":
text_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
image_token_index = 128256
elif model_id == "lmms-lab/llava-next-72b":
text_model_id = "Qwen/Qwen1.5-72B-Chat"
image_token_index = 151646
elif model_id == "lmms-lab/llava-next-110b":
text_model_id = "Qwen/Qwen1.5-110B-Chat"
image_token_index = 151646

vision_model_id = data["mm_vision_tower"]

torch.set_default_dtype(torch.float16)
Expand All @@ -120,7 +131,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
tokenizer = AutoTokenizer.from_pretrained(text_model_id, use_fast=use_fast)
tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)

if model_id == "liuhaotian/llava-v1.6-mistral-7b":
if model_id in ("liuhaotian/llava-v1.6-mistral-7b", "lmms-lab/llama3-llava-next-8b"):
# Mistral-7B doesn't have a padding token set yet
tokenizer.add_special_tokens({"pad_token": "<pad>"})

Expand Down Expand Up @@ -151,28 +162,45 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):

# We add an image token so we resize the model
# Pad to 64 for performance reasons
pad_shape = 64
vocab_size = config.text_config.vocab_size
if model_id == "liuhaotian/llava-v1.6-34b":
# this one has 3 additional tokens, namely <|startoftext|>, <|endoftext|> and <image>
num_tokens = vocab_size + 3
else:
# this one has 2 additional tokens, namely <image> and <pad>
num_tokens = vocab_size + 2
model.resize_token_embeddings(num_tokens, pad_to_multiple_of=pad_shape)
model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack(
tuple(
(dist.sample() for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0]))
),
dim=0,
)
model.language_model.lm_head.weight.data[vocab_size:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))),
dim=0,
)
# Qwen-based models have extra unused space in the vocab size already, so no need to resize
if model_id not in ["lmms-lab/llava-next-72b", "lmms-lab/llava-next-110b"]:
pad_shape = 64
vocab_size = config.text_config.vocab_size
if model_id == "liuhaotian/llava-v1.6-34b":
# this one has 3 additional tokens, namely <|startoftext|>, <|endoftext|> and <image>
num_tokens = vocab_size + 3
else:
# this one has 2 additional tokens, namely <image> and <pad>
num_tokens = vocab_size + 2
model.resize_token_embeddings(num_tokens, pad_to_multiple_of=pad_shape)
model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack(
tuple(
(
dist.sample()
for _ in range(model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0])
)
),
dim=0,
)
model.language_model.lm_head.weight.data[vocab_size:] = torch.stack(
tuple((dist.sample() for _ in range(model.language_model.lm_head.weight.data[vocab_size:].shape[0]))),
dim=0,
)

print(f"Saving model and processor for {model_id} to {pytorch_dump_folder_path}")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)

# Make space so we can load the model properly now.
del state_dict
gc.collect()

device = "cuda:2"
model.to(device)
# Load everything back for inference tests in float32 because prev script was written as that
# Though it's mostly loaded in fp16 as original weights are in fp16
model = LlavaNextForConditionalGeneration.from_pretrained(pytorch_dump_folder_path, device_map="auto")
processor = LlavaNextProcessor.from_pretrained(pytorch_dump_folder_path)
device = model.device

# prepare inputs
image = load_image()
Expand All @@ -182,6 +210,11 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
elif model_id == "liuhaotian/llava-v1.6-34b":
prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
elif model_id == "lmms-lab/llama3-llava-next-8b":
prompt = "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n<image>\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
elif model_id in ["lmms-lab/llava-next-72b", "lmms-lab/llava-next-110b"]:
prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|>\n<|im_start|>assistant\n"

inputs = processor(images=image, text=prompt, return_tensors="pt")

# verify inputs
Expand All @@ -194,8 +227,6 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
original_input_ids = torch.load(filepath, map_location="cpu")
# replace -200 by image_token_index (since we use token ID = 32000 for the image token)
original_input_ids[original_input_ids == -200] = image_token_index
print(tokenizer.decode([id for id in original_input_ids.tolist()[0] if id != -200]))

assert original_input_ids[0].tolist() == inputs.input_ids[0].tolist()

elif model_id == "liuhaotian/llava-v1.6-34b":
Expand Down Expand Up @@ -243,6 +274,26 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
dtype=torch.float32,
device=device,
)
elif model_id == "lmms-lab/llama3-llava-next-8b":
expected_slice = torch.tensor(
[[-3.9648, 1.1396, 3.3145], [-5.3594, -1.5654, -1.9619], [-12.3750, -10.6797, -9.3125]],
dtype=torch.float32,
device=device,
)
elif model_id == "lmms-lab/llava-next-72b":
# Not yet checked against reference
expected_slice = torch.tensor(
[[3.7148, 3.9277, 3.4395], [-0.4341, 1.1387, 6.5117], [3.2324, 3.4688, 4.1133]],
dtype=torch.float32,
device=device,
)
elif model_id == "lmms-lab/llava-next-110b":
# Not yet checked against reference
expected_slice = torch.tensor(
[[-2.5449, -1.6738, -2.0371], [1.0811, 3.4961, 5.0312], [1.7803, 2.5137, 2.4277]],
dtype=torch.float32,
device=device,
)
else:
raise ValueError(f"Model {model_id} not supported")

Expand All @@ -268,6 +319,12 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
expected_text = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a radar chart, also known as a spider chart or star chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular radar chart, there are several variables represented:\n\n- MM-Vet\n- LLa-Va-Bench\n- SEED-Bench\n- MM"
elif model_id == "liuhaotian/llava-v1.6-34b":
expected_text = "<|im_start|> system\nAnswer the questions. <|im_start|> user\n\nWhat is shown in this image? <|im_start|> assistant\nThe image appears to be a radar chart, also known as a spider chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular chart, there are several datasets represented by different colors and labeled with various acronyms such as MM-Vet, LLaVA-Bench, SEED-Bench, MM-Bench-CN, MM-"
elif model_id == "lmms-lab/llama3-llava-next-8b":
expected_text = 'system\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.user\n\n\nWhat is shown in this image?assistant\n\n\nThe image shows a radar chart, also known as a spider chart or a web chart, which is a type of graph used to display multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the values are plotted along each axis and connected to form a polygon.\n\nIn this particular radar chart, there are several axes labeled with different variables, such as "MM-Vet," "LL'
elif model_id == "lmms-lab/llava-next-72b":
expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image displays a radar chart, also known as a spider chart or a star chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the value of each variable is represented by the distance from the center of the chart to the point where the axis intersects with the line representing that variable's value.\n\nIn this particular chart, there are several axes"
elif model_id == "lmms-lab/llava-next-110b":
expected_text = "system\nYou are a helpful assistant.\nuser\n\nWhat is shown in this image?\nassistant\nThe image shows a radar chart comparing the performance of different models on various visual question answering (VQA) benchmarks. Each colored line represents a different model, and the distance from the center of the chart indicates the score or performance level of the model on a particular benchmark. The benchmarks are labeled around the edges of the chart, and include VQA v2, GQA, VizWiz, TextVQA, MMBench-CN, MME, and others. The chart allows for a"
else:
raise ValueError(f"Model {model_id} not supported")

Expand All @@ -281,7 +338,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):

inputs = processor(
images=[image, cats_image],
text=[prompt, "[INST] <image>\nHow many cats are there? [/INST]"],
text=[prompt, prompt],
padding=True,
return_tensors="pt",
).to(device)
Expand All @@ -305,16 +362,11 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
print(outputs)

if pytorch_dump_folder_path is not None:
print(f"Saving model and processor for {model_id} to {pytorch_dump_folder_path}")
Path(pytorch_dump_folder_path).mkdir(exist_ok=True)
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)

if push_to_hub:
repo_id = model_id.split("/")[-1]
model.push_to_hub(f"llava-hf/{repo_id}-hf")
processor.push_to_hub(f"llava-hf/{repo_id}-hf")
checkpoint_name = model_id.split("/")[-1]
print(f"Pushing to repo llava-hf/{checkpoint_name}-hf")
model.push_to_hub(f"llava-hf/{checkpoint_name}-hf")
processor.push_to_hub(f"llava-hf/{checkpoint_name}-hf")


if __name__ == "__main__":
Expand All @@ -328,11 +380,14 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
"liuhaotian/llava-v1.6-vicuna-7b",
"liuhaotian/llava-v1.6-vicuna-13b",
"liuhaotian/llava-v1.6-34b",
"lmms-lab/llama3-llava-next-8b",
"lmms-lab/llava-next-72b",
"lmms-lab/llava-next-110b",
],
required=False,
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model directory."
"--pytorch_dump_folder_path", type=str, required=True, help="Path to the output PyTorch model directory."
)
parser.add_argument(
"--push_to_hub", action="store_true", help="Whether or not to push the converted model to the 🤗 hub."
Expand Down

0 comments on commit 5ba99cb

Please sign in to comment.