Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multimodal integration - pixtral/llava/qwen2-vl #2170

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
63 changes: 63 additions & 0 deletions examples/llava/lora-7b.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
base_model: llava-hf/llava-1.5-7b-hf
processor_type: AutoProcessor
strict: false

# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: llava
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

adapter: lora
lora_model_dir:

sequence_len: 8192
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
65 changes: 65 additions & 0 deletions examples/pixtral/lora-12b.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
base_model: mistral-community/pixtral-12b
processor_type: AutoProcessor
strict: false

# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: pixtral
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

adapter: lora
lora_model_dir:

sequence_len: 8192
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'language_model.model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: false # PixtralVisionModel does not support Flash Attention 2.0 yet
eager_attention:

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>
63 changes: 63 additions & 0 deletions examples/qwen2-vl/lora-7b.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
base_model: Qwen/Qwen2-VL-7B-Instruct
processor_type: AutoProcessor
strict: false

# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: qwen2_vl
datasets:
- path: HuggingFaceH4/llava-instruct-mix-vsft
type: chat_template
split: train[:1%]
field_messages: messages
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

adapter: lora
lora_model_dir:

sequence_len: 8192
pad_to_sequence_len: false

lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules: 'model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: true
fp16:
tf32: true

gradient_checkpointing: true
local_rank:
logging_steps: 1
flash_attention: true
eager_attention:

warmup_ratio: 0.1
evals_per_epoch: 1
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
1 change: 1 addition & 0 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2014,6 +2014,7 @@ def build_collator(
collator = MultiModalChatDataCollator
kwargs["processor"] = self.processor
kwargs["chat_template"] = training_args.chat_template
kwargs["chat_template_type"] = self.cfg.chat_template
elif self.cfg.batch_flattening:
collator = DataCollatorWithFlattening
collator_args.pop(0)
Expand Down
3 changes: 3 additions & 0 deletions src/axolotl/utils/chat_templates.py

Large diffs are not rendered by default.

Loading
Loading