feat(minicpm-v): Support MiniCPM-V inference/training pipeline #749

Open · wants to merge 9 commits into master
606 changes: 606 additions & 0 deletions examples/minicpm_v/finetune/dataset.py

Large diffs are not rendered by default.

328 changes: 328 additions & 0 deletions examples/minicpm_v/finetune/finetune.py
@@ -0,0 +1,328 @@
import json
import os
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import transformers
from transformers import HfArgumentParser

import mindspore as ms
from mindspore import nn
from mindspore.dataset import transforms, vision
from mindspore.train.amp import AMP_BLACK_LIST, _auto_black_list

mindone_lib_path = os.path.abspath("../../../")
sys.path.insert(0, mindone_lib_path)

from dataset import SupervisedDataset
from transformers import AutoTokenizer

from mindone.transformers.mindspore_adapter import MindSporeArguments
from mindone.transformers.models.minicpm_v import MiniCPMV_v2_6
from mindone.transformers.trainer import Trainer
from mindone.transformers.training_args import TrainingArguments


@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="openbmb/MiniCPM-V-2_6")


@dataclass
class DataArguments:
data_path: str = field(default=None, metadata={"help": "Path to the training data."})
eval_data_path: str = field(default=None, metadata={"help": "Path to the evaluation data."})


@dataclass
class LoraArguments:
lora_r: int = 64
lora_alpha: int = 64
lora_dropout: float = 0.05
lora_target_modules: str = r"llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj)"
lora_weight_path: str = ""
lora_bias: str = "none"
q_lora: bool = False
lora_modules_to_save: str = ""
lora_layer_replication: Optional[List[Tuple[int, int]]] = None
lora_layers_to_transform: Optional[List[int]] = None
lora_layers_pattern: Optional[str] = None


@dataclass
class MyArguments(MindSporeArguments, TrainingArguments):
enable_flash_attention: bool = field(default=False)
gradient_checkpointing: bool = field(default=False)
is_distribute: bool = field(default=False)
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_mindspore")
model_max_length: int = field(
default=2048,
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
)
tune_vision: Optional[bool] = field(default=True)
tune_llm: Optional[bool] = field(default=True)
llm_type: str = field(default="minicpm")
use_lora: Optional[bool] = field(default=False)
max_slice_nums: Optional[int] = field(default=9)
distributed: Optional[bool] = field(default=False)
amp_level: Optional[str] = field(default="O0")


local_rank = None


def rank0_print(*args):
if local_rank == 0:
print(*args)


def safe_save_model_for_hf_trainer(trainer, output_dir: str, bias="none"):
"""Collects the state dict and dump to disk."""
if trainer.args.should_save and trainer.args.local_rank == 0:
trainer.save_model(
output_dir,
)


# class ModifiedMapFunction(BaseMapFuction):
# def __call__(self, input_ids, position_ids, labels, attention_mask):
# return trim_and_pad(input_ids), trim_and_pad(position_ids), trim_and_pad(labels), trim_and_pad(attention_mask)


def make_supervised_data_module(
tokenizer: transformers.PreTrainedTokenizer,
data_args,
transform,
data_collator=None,
llm_type="minicpm",
slice_config=None,
patch_size=14,
query_nums=64,
batch_vision=False,
max_length=2048,
) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""

dataset_cls = SupervisedDataset

rank0_print("Loading data...")

train_json = json.load(open(data_args.data_path, "r"))
train_dataset = dataset_cls(
train_json,
transform,
tokenizer,
slice_config=slice_config,
llm_type=llm_type,
patch_size=patch_size,
query_nums=query_nums,
batch_vision=batch_vision,
max_length=max_length,
)

# train_ds = dataset.GeneratorDataset(
# train_dataset,
# column_names=train_dataset.dataset_column_names,
# num_parallel_workers=2,
# shuffle=True,
# python_multiprocessing=False,
# num_shards=rank_size,
# shard_id=rank
# )

if data_args.eval_data_path:
eval_json = json.load(open(data_args.eval_data_path, "r"))
eval_dataset = dataset_cls(
eval_json,
transform,
tokenizer,
slice_config=slice_config,
llm_type=llm_type,
patch_size=patch_size,
query_nums=query_nums,
batch_vision=batch_vision,
max_length=max_length,
)

# eval_ds = dataset.GeneratorDataset(
# eval_dataset,
# column_names=eval_dataset.dataset_column_names,
# num_parallel_workers=8,
# shuffle=False,
# python_multiprocessing=False,
# )
else:
eval_dataset = None

# def trim_and_pad(seq):
# # return pad_sequence([s[:max_length] for s in seq], batch_first=True, padding_value=padding_value)
# max_length = 2048
# return np.stack([s[:max_length] for s in seq])
#
# class ModifiedMapFunction(BaseMapFuction):
# def __call__(self, input_ids, position_ids, labels, attention_mask):
# return trim_and_pad(input_ids), trim_and_pad(position_ids), trim_and_pad(labels), trim_and_pad(attention_mask)

return dict(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)


# def build_transform():
# IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
# IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
# return transforms.Compose(
# [
# vision.ToTensor(),
# vision.Normalize(
# mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, is_hwc=False
# ),
# ]
# )


def build_transform():
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_MEAN
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) # timm.data.IMAGENET_INCEPTION_STD
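# Note: only Normalize is applied here; unlike the commented-out variant above, ToTensor is
# omitted (the dataset is assumed to already provide CHW float arrays, hence is_hwc=False).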
return transforms.Compose(
[
vision.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, is_hwc=False),
]
)


def get_parameter_number(model):
trainable_params = 0
# for param in model.parameters():
# num_params = param.numel()
# # if using DS Zero 3 and the weights are initialized empty
# if num_params == 0 and hasattr(param, "ds_numel"):
# num_params = param.ds_numel
#
# all_param += num_params
# if param.requires_grad:
# trainable_params += num_params
for param in model.trainable_params():
num_params = np.prod(param.shape)
trainable_params += num_params

return {"Trainable params": trainable_params}


local_rank = 0


def train():
global local_rank
parser = HfArgumentParser((ModelArguments, DataArguments, MyArguments, LoraArguments))

(
model_args,
data_args,
training_args,
lora_args,
) = parser.parse_args_into_dataclasses()

# if getattr(training_args, "deepspeed", None) :
# training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED

compute_dtype = ms.float16 if training_args.fp16 else (ms.bfloat16 if training_args.bf16 else ms.float32)

# if training_args.distributed:
# init()
# data_args.rank, data_args.rank_size, parallel_mode = get_rank(), get_group_size(), context.ParallelMode.DATA_PARALLEL
# context.set_auto_parallel_context(
# device_num=data_args.rank_size, parallel_mode=parallel_mode, gradients_mean=True
# )
# else:
# data_args.rank, data_args.rank_size, parallel_mode = 0, 1, None

local_rank = training_args.local_rank

model = MiniCPMV_v2_6.from_pretrained(
model_args.model_name_or_path,
trust_remote_code=True,
mindspore_dtype=compute_dtype,
)

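# AMP setup: under O2, _auto_black_list casts the network to float16 while keeping cells in the
# AMP black list (extended here with GroupNorm and SiLU) in float32; under O3 the whole network
# is cast to float16.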
if training_args.amp_level == "O2":
_auto_black_list(
model,
AMP_BLACK_LIST + [nn.GroupNorm, nn.SiLU],
ms.float16,
)
elif training_args.amp_level == "O3":
model.to_float(ms.float16)

# if training_args.distributed:
# # set grad reducer
# mean = ms.context.get_auto_parallel_context("gradients_mean")
# degree = ms.context.get_auto_parallel_context("device_num")
# grad_reducer = nn.DistributedGradReducer(model.trainable_params(), mean, degree)
# else:
# grad_reducer = None

tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

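# Optionally freeze the vision encoder (vpm) and/or the language backbone (llm) by turning off
# gradients on their parameters when they are not being tuned.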
if not training_args.tune_vision:
# model.vpm.set_train(False)
for param in model.vpm.trainable_params():
param.requires_grad = False
if not training_args.tune_llm:
# model.llm.set_train(False)
for param in model.llm.trainable_params():
param.requires_grad = False

rank0_print(get_parameter_number(model))

llm_type = training_args.llm_type

rank0_print(f"llm_type={llm_type}")

# Prepare data: resolve the image slicing config from the model, then build the datasets
if hasattr(model.config, "slice_config"):
model.config.slice_config.max_slice_nums = training_args.max_slice_nums
slice_config = model.config.slice_config.to_dict()
else:
model.config.max_slice_nums = training_args.max_slice_nums
slice_config = model.config.to_dict()

if hasattr(model.config, "batch_vision_input"):
batch_vision = model.config.batch_vision_input
else:
batch_vision = False

transform_func = build_transform()
data_module = make_supervised_data_module(
tokenizer=tokenizer,
data_args=data_args,
transform=transform_func,
data_collator=None,
slice_config=slice_config,
llm_type=llm_type,
patch_size=model.config.patch_size,
query_nums=model.config.query_num,
batch_vision=batch_vision,
max_length=training_args.model_max_length,
)

training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
trainer = Trainer(
model=model,
tokenizer=tokenizer,
args=training_args,
**data_module,
)

trainer.train()
# trainer.save_state()

safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir, bias=lora_args.lora_bias)


if __name__ == "__main__":
train()
43 changes: 43 additions & 0 deletions examples/minicpm_v/finetune/finetune.sh
@@ -0,0 +1,43 @@
#!/bin/bash

MODEL="openbmb/MiniCPM-V-2_6"
# or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5
# ATTENTION: specify the path to your training data, which should be a JSON file containing a list of conversation samples.
# See the finetuning section of the README for more information; an illustrative record is sketched in the comments below.
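# Illustrative sketch of one record (field names follow the upstream MiniCPM-V finetune format;
# treat this as an assumption and check dataset.py / the README for the authoritative schema):
# [
#   {
#     "id": "0",
#     "image": "path/to/image_0.jpg",
#     "conversations": [
#       {"role": "user", "content": "<image>\nWhat is shown in the picture?"},
#       {"role": "assistant", "content": "A short answer describing the picture."}
#     ]
#   }
# ]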
DATA="/data3/wcr/mindone/examples/minicpm/finetune/finetune.json"
#EVAL_DATA="path/to/test_data"
LLM_TYPE="qwen2" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3"
MODEL_MAX_LENGTH=2048 # for multi-image SFT, set MODEL_MAX_LENGTH=4096

mkdir -p pynative_logs  # ensure the log directory used by the redirect below exists

python finetune.py \
--model_name_or_path $MODEL \
--llm_type $LLM_TYPE \
--data_path $DATA \
--remove_unused_columns false \
--label_names "labels" \
--prediction_loss_only false \
--bf16 false \
--bf16_full_eval false \
--fp16 false \
--fp16_full_eval false \
--do_train \
--tune_vision true \
--tune_llm false \
--model_max_length $MODEL_MAX_LENGTH \
--max_slice_nums 9 \
--max_steps 10000 \
--output_dir output/output_minicpmv26 \
--logging_dir output/output_minicpmv26 \
--logging_strategy "steps" \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 1 \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 1e-6 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
> pynative_logs/train_vision.log 2>&1 &