-
Notifications
You must be signed in to change notification settings - Fork 94
/
smolvlm.py
137 lines (121 loc) · 4.07 KB
/
smolvlm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import torch
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics3ForConditionalGeneration
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
import os
from PIL import Image
from transformers.image_utils import load_image
# Select the physical GPUs before anything can initialise CUDA: these
# variables only take effect if they are set before the first CUDA call.
# The device list is written without spaces (plain comma-separated ids).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1,4"

# Fine-tuning mode switches:
#   USE_QLORA -> 4-bit quantized base model + LoRA adapters
#   USE_LORA  -> full-precision base model + (Do)RA adapters
#   neither   -> full fine-tuning
USE_LORA = False
USE_QLORA = True
# Pick the small SmolVLM checkpoint or the 8B Idefics3 checkpoint.
SMOL = True

model_id = "HuggingFaceTB/SmolVLM-Base" if SMOL else "HuggingFaceM4/Idefics3-8B-Llama3"
processor = AutoProcessor.from_pretrained(model_id)
# Device for the full fine-tuning branch (the (Q)LoRA branch lets
# device_map="auto" place the weights instead).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

if USE_QLORA or USE_LORA:
    # LoRA adapters on the attention and MLP projection layers. DoRA is used
    # only for plain LoRA; the QLoRA path keeps vanilla LoRA.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules=['down_proj', 'o_proj', 'k_proj', 'q_proj', 'gate_proj', 'up_proj', 'v_proj'],
        use_dora=not USE_QLORA,
        init_lora_weights="gaussian",
    )
    lora_config.inference_mode = False

    bnb_config = None
    if USE_QLORA:
        # 4-bit NF4 weights with double quantization; compute in bfloat16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    model = Idefics3ForConditionalGeneration.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        _attn_implementation="flash_attention_2",
        device_map="auto",
    )
    # Prepare the base model for adapter training, then inject the adapters
    # exactly once via get_peft_model. Calling model.add_adapter() /
    # enable_adapters() beforehand would register the same LoRA config on the
    # model a second time.
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
    # Presumably (trainable, total) parameter counts — see PEFT docs.
    print(model.get_nb_trainable_parameters())
else:
    model = Idefics3ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2",
    ).to(DEVICE)

    # If you'd like to fine-tune only the LLM, freeze the vision encoder.
    for param in model.model.vision_model.parameters():
        param.requires_grad = False
# Load the small VQAv2 dataset and carve a training subset out of its
# validation split: with test_size=0.8 the "train" side keeps 20% of rows.
ds = load_dataset('merve/vqav2-small', trust_remote_code=True)
split_ds = ds["validation"].train_test_split(test_size=0.8)
train_ds = split_ds["train"]

# Look up the token id of the "<image>" placeholder among the tokenizer's
# additional special tokens; collate_fn masks it out of the labels.
_tok = processor.tokenizer
image_token_id = _tok.additional_special_tokens_ids[
    _tok.additional_special_tokens.index("<image>")
]
del _tok
def collate_fn(examples):
    """Build a padded training batch with labels from VQAv2 rows.

    Each example must provide "image" (converted to RGB when needed),
    "question" and "multiple_choice_answer". The pair is rendered with the
    processor's chat template. The user turn holds a fixed instruction, the
    image and the question; the assistant turn holds the answer.

    Returns the processor's batch plus "labels": a copy of input_ids in which
    padding positions and "<image>" placeholder tokens are set to -100 so the
    loss ignores them.
    """
    texts = []
    images = []
    for example in examples:
        image = example["image"]
        if image.mode != 'RGB':
            image = image.convert('RGB')
        question = example["question"]
        answer = example["multiple_choice_answer"]
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Answer briefly."},
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": answer},
                ],
            },
        ]
        text = processor.apply_chat_template(messages, add_generation_prompt=False)
        texts.append(text.strip())
        images.append([image])  # one image per sample

    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    labels = batch["input_ids"].clone()
    attention_mask = batch.get("attention_mask")
    if attention_mask is not None:
        # Mask padding by position, not by token id. If the pad token shares
        # its id with a real token (e.g. EOS), masking by id would also hide
        # that real token from the loss.
        labels[attention_mask == 0] = -100
    else:
        labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    batch["labels"] = labels
    return batch
# Last path component of the checkpoint id, used to name outputs.
model_name = model_id.rsplit("/", 1)[-1]

training_args = TrainingArguments(
    # Where checkpoints go locally and on the Hub.
    output_dir=f"./{model_name}-vqav2",
    hub_model_id=f"{model_name}-vqav2",
    # Schedule: effective batch = 8 per device x 4 accumulation steps.
    num_train_epochs=1,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    # Paged 8-bit AdamW from bitsandbytes; switch to a standard AdamW
    # optimizer when not training with bitsandbytes.
    optim="paged_adamw_8bit",
    # bfloat16 mixed-precision compute.
    bf16=True,
    gradient_checkpointing=True,
    # Logging and checkpointing; only the newest checkpoint is kept.
    logging_steps=25,
    report_to="tensorboard",
    save_strategy="steps",
    save_steps=250,
    save_total_limit=1,
    # Keep raw dataset columns: collate_fn needs "image", "question", etc.
    remove_unused_columns=False,
)
# Wire model, hyper-parameters, data and collator together, run training,
# then push the result to the Hugging Face Hub.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    data_collator=collate_fn,
)
trainer.train()
trainer.push_to_hub()