Add llama3-llava-next-8b to llava_next conversion script #31395

Merged
@@ -111,6 +111,10 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
elif model_id == "liuhaotian/llava-v1.6-34b":
text_model_id = "NousResearch/Nous-Hermes-2-Yi-34B"
image_token_index = 64000
elif model_id == "lmms-lab/llama3-llava-next-8b":
text_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
image_token_index = 128256

vision_model_id = data["mm_vision_tower"]

torch.set_default_dtype(torch.float16)
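As a sanity check on the new image_token_index: Meta-Llama-3-8B-Instruct ships with a 128,256-token vocabulary (ids 0-128255), so the <image> token registered a few lines below lands exactly at id 128256. A minimal sketch, assuming access to the gated Llama-3 checkpoint:

from transformers import AddedToken, AutoTokenizer

# Mirror the conversion script: load the Llama-3 tokenizer and append <image>.
# The new token takes the first free id after the 128,256-entry base vocabulary.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)
print(tokenizer.convert_tokens_to_ids("<image>"))  # expected: 128256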
@@ -120,7 +124,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
tokenizer = AutoTokenizer.from_pretrained(text_model_id, use_fast=use_fast)
tokenizer.add_tokens(AddedToken("<image>", special=True, normalized=False), special_tokens=True)

if model_id == "liuhaotian/llava-v1.6-mistral-7b":
if model_id == "liuhaotian/llava-v1.6-mistral-7b" or model_id == "lmms-lab/llama3-llava-next-8b":
# Mistral-7B and Llama-3 don't have a padding token set yet
tokenizer.add_special_tokens({"pad_token": "<pad>"})
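For context on why this branch now covers the Llama-3 model as well: like Mistral-7B, the Llama-3 instruct tokenizer defines no padding token out of the box, and one is needed for batched processing. A short check, under the same assumption about checkpoint access:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
print(tokenizer.pad_token)  # None: no padding token is defined by default

# The conversion script therefore registers one explicitly.
tokenizer.add_special_tokens({"pad_token": "<pad>"})
print(tokenizer.pad_token)  # <pad>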

@@ -182,6 +186,9 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>\nWhat is shown in this image? ASSISTANT:"
elif model_id == "liuhaotian/llava-v1.6-34b":
prompt = "<|im_start|>system\nAnswer the questions.<|im_end|><|im_start|>user\n<image>\nWhat is shown in this image?<|im_end|><|im_start|>assistant\n"
elif model_id == "lmms-lab/llama3-llava-next-8b":
prompt = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.<|eot_id|><|start_header_id|><|start_header_id|>user<|end_header_id|>\n\n<image>\nWhat is shown in this image?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

inputs = processor(images=image, text=prompt, return_tensors="pt")

# verify inputs
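Rather than eyeballing the hard-coded Llama-3 prompt above, it can be cross-checked against the tokenizer's own chat template, which also emits the leading <|begin_of_text|> and the trailing assistant header. A sketch using the standard apply_chat_template API:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
messages = [
    {
        "role": "system",
        "content": "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
    },
    {"role": "user", "content": "<image>\nWhat is shown in this image?"},
]
# add_generation_prompt=True appends <|start_header_id|>assistant<|end_header_id|>\n\n,
# matching the suffix of the string used in the script.
print(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))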
@@ -243,6 +250,12 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
dtype=torch.float32,
device=device,
)
elif model_id == "lmms-lab/llama3-llava-next-8b":
expected_slice = torch.tensor(
[[-3.9648, 1.1396, 3.3145], [-3.9648, 1.1396, 3.3145], [-3.6426, -0.0081, -0.1266]],
dtype=torch.float32,
device=device,
)
else:
raise ValueError(f"Model {model_id} not supported")
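For reference, the expected_slice tensors above are checked against the top-left corner of the logits from a single forward pass; the comparison is along these lines (a sketch, reusing the model and inputs built earlier in the script, with the tolerance as an assumption):

import torch

# One forward pass with the converted model.
with torch.no_grad():
    outputs = model(**inputs)

# Compare the first 3x3 block of the logits against the recorded slice.
actual_slice = outputs.logits[0, :3, :3]
assert torch.allclose(actual_slice, expected_slice, atol=1e-4), "logits mismatch"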

@@ -268,6 +281,8 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
expected_text = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: \nWhat is shown in this image? ASSISTANT: The image appears to be a radar chart, also known as a spider chart or star chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular radar chart, there are several variables represented:\n\n- MM-Vet\n- LLa-Va-Bench\n- SEED-Bench\n- MM"
elif model_id == "liuhaotian/llava-v1.6-34b":
expected_text = "<|im_start|> system\nAnswer the questions. <|im_start|> user\n\nWhat is shown in this image? <|im_start|> assistant\nThe image appears to be a radar chart, also known as a spider chart, which is a graphical method of displaying multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point.\n\nIn this particular chart, there are several datasets represented by different colors and labeled with various acronyms such as MM-Vet, LLaVA-Bench, SEED-Bench, MM-Bench-CN, MM-"
elif model_id == "lmms-lab/llama3-llava-next-8b":
expected_text = 'system\n\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.user\n\n\nWhat is shown in this image?assistant\n\n\nThe image appears to be a radar chart, also known as a spider chart or a web chart, which is a type of graph that displays multivariate data in the form of a two-dimensional chart of three or more quantitative variables represented on axes starting from the same point. Each axis represents a different variable, and the values are plotted along each axis and connected to form a polygon.\n\nIn this particular radar chart, there are several axes labeled with different variables, such as "MM-Vet," "'
else:
raise ValueError(f"Model {model_id} not supported")
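Similarly, the expected_text strings above are compared against greedy decoding output; a sketch of that step, with the generation length and decoding flags as assumptions:

# Greedy generation, then decode for comparison against expected_text.
output_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
generated_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
assert generated_text == expected_text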

@@ -328,6 +343,7 @@ def convert_llava_to_hf(model_id, pytorch_dump_folder_path, push_to_hub=False):
"liuhaotian/llava-v1.6-vicuna-7b",
"liuhaotian/llava-v1.6-vicuna-13b",
"liuhaotian/llava-v1.6-34b",
"lmms-lab/llama3-llava-next-8b",
],
required=False,
)
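With the new checkpoint registered in the choices list, the conversion can be run end to end; a usage sketch, assuming the script's usual location in the repo and flag names mirroring the function parameters (the dump folder path here is illustrative):

python src/transformers/models/llava_next/convert_llava_next_weights_to_hf.py \
    --model_id lmms-lab/llama3-llava-next-8b \
    --pytorch_dump_folder_path ./llama3-llava-next-8b-hf \
    --push_to_hub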