Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Dec 24, 2024
1 parent 98a6091 commit 97e39da
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 16 deletions.
1 change: 1 addition & 0 deletions docs/source/Customization/自定义数据集.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ query-response格式:

微调:
```jsonl
{"messages": [{"role": "user", "content": "浙江的省会在哪?"}, {"role": "assistant", "content": "浙江的省会在杭州。"}]}
{"messages": [{"role": "user", "content": "<image><image>两张图片有什么区别"}, {"role": "assistant", "content": "前一张是小猫,后一张是小狗"}], "images": ["/xxx/x.jpg", "/xxx/x.png"]}
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/xxx/x.mp3"]}
{"messages": [{"role": "system", "content": "你是个有用无害的助手"}, {"role": "user", "content": "<image>图片中是什么,<video>视频中是什么"}, {"role": "assistant", "content": "图片中是一个大象,视频中是一只小狗在草地上奔跑"}], "images": ["/xxx/x.jpg"], "videos": ["/xxx/x.mp4"]}
Expand Down
13 changes: 7 additions & 6 deletions docs/source_en/Customization/Custom-dataset.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,15 @@ For multimodal datasets, the format is the same as the tasks mentioned above. Th
Pre-training:
```jsonl
{"messages": [{"role": "assistant", "content": "Pre-training text goes here"}]}
{"messages": [{"role": "assistant", "content": "<image> is a puppy, <image> is a kitten"}], "images": ["/xxx/x.jpg", "/xxx/x.png"]}
{"messages": [{"role": "assistant", "content": "<audio> describes how nice the weather is today"}], "audios": ["/xxx/x.wav"]}
{"messages": [{"role": "assistant", "content": "<image> is an elephant, <video> is a lion running"}], "images": ["/xxx/x.jpg"], "videos": ["/xxx/x.mp4"]}
{"messages": [{"role": "assistant", "content": "<image>is a puppy, <image>is a kitten"}], "images": ["/xxx/x.jpg", "/xxx/x.png"]}
{"messages": [{"role": "assistant", "content": "<audio>describes how nice the weather is today"}], "audios": ["/xxx/x.wav"]}
{"messages": [{"role": "assistant", "content": "<image>is an elephant, <video>is a lion running"}], "images": ["/xxx/x.jpg"], "videos": ["/xxx/x.mp4"]}
```

Supervised Fine-tuning:

```jsonl
{"messages": [{"role": "user", "content": "Where is the capital of Zhejiang?"}, {"role": "assistant", "content": "The capital of Zhejiang is Hangzhou."}]}
{"messages": [{"role": "user", "content": "<image><image>What is the difference between the two images?"}, {"role": "assistant", "content": "The first one is a kitten, and the second one is a puppy."}], "images": ["/xxx/x.jpg", "/xxx/x.png"]}
{"messages": [{"role": "user", "content": "<audio>What did the audio say?"}, {"role": "assistant", "content": "The weather is really nice today."}], "audios": ["/xxx/x.mp3"]}
{"messages": [{"role": "system", "content": "You are a helpful and harmless assistant."}, {"role": "user", "content": "<image>What is in the image, <video>What is in the video?"}, {"role": "assistant", "content": "The image shows an elephant, and the video shows a puppy running on the grass."}], "images": ["/xxx/x.jpg"], "videos": ["/xxx/x.mp4"]}
Expand All @@ -93,7 +94,7 @@ The data format for RLHF can refer to the format used for pure text large models
For grounding (object detection) tasks, SWIFT supports two methods:
1. Keep the same format as the multimodal datasets above and embed the model's special characters directly in the data, for example:
```jsonl
{"messages": [{"role": "system", "content": "You are a useful and harmless assistant"}, {"role": "user", "content": "<image> Find a <ref> elephant </ref>"}, {"role": "assistant", "content": "<box>(200,450),(500,800)</box>"}], "images": ["/xxx/x.jpg"]}
{"messages": [{"role": "system", "content": "You are a useful and harmless assistant"}, {"role": "user", "content": "<image>Find a <ref> elephant </ref>"}, {"role": "assistant", "content": "<box>(200,450),(500,800)</box>"}], "images": ["/xxx/x.jpg"]}
```
With this type of data, please note:
- Grounding tasks often require special characters. Decide which model you will use, consult that model's paper to find the special characters it expects for grounding tasks, and construct the data accordingly.
Expand All @@ -104,9 +105,9 @@ With this type of data, please note:

```jsonl
# Object detection
{"messages": [{"role": "system", "content": "You are a useful and harmless assistant"}, {"role": "user", "content": "<image> Identify <bbox>"}, {"role": "assistant", "content": "<ref-object>"}], "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[{\"caption\": \"guy in red\", \"bbox\": [138, 136, 235, 359], \"bbox_type\": \"real\", \"image\": 0}]"}
{"messages": [{"role": "system", "content": "You are a useful and harmless assistant"}, {"role": "user", "content": "<image>Identify <bbox>"}, {"role": "assistant", "content": "<ref-object>"}], "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[{\"caption\": \"guy in red\", \"bbox\": [138, 136, 235, 359], \"bbox_type\": \"real\", \"image\": 0}]"}
# Grounding to multiple bboxes
{"messages": [{"role": "system", "content": "You are a useful and harmless assistant"}, {"role": "user", "content": "<image> Find <ref-object>"}, {"role": "assistant", "content": "<bbox>"}], "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[{\"caption\": \"guy in red\", \"bbox\": [[138, 136, 235, 359], [1,2,3,4]], \"bbox_type\": \"real\", \"image\": 0}]"}
{"messages": [{"role": "system", "content": "You are a useful and harmless assistant"}, {"role": "user", "content": "<image>Find <ref-object>"}, {"role": "assistant", "content": "<bbox>"}], "images": ["/coco2014/train2014/COCO_train2014_000000001507.jpg"], "objects": "[{\"caption\": \"guy in red\", \"bbox\": [[138, 136, 235, 359], [1,2,3,4]], \"bbox_type\": \"real\", \"image\": 0}]"}
```

This format adds the objects field, which includes:
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def pre_forward_hook(self, model: nn.Module, args, kwargs):
kwargs.pop('input_ids', None)

if isinstance(model, PeftModel):
parameters = inspect.signature(model.base_model.model.forward).parameters
parameters = inspect.signature(model.model.forward).parameters
else:
parameters = inspect.signature(model.forward).parameters
if 'position_ids' not in parameters:
Expand Down
6 changes: 3 additions & 3 deletions swift/llm/train/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,14 @@ def prepare_adapter(args: TrainArguments, model, *, template=None, train_dataset
'lorap_lr_ratio': args.lorap_lr_ratio,
'init_lora_weights': args.init_weights,
}

task_type = 'CAUSAL_LM' if args.num_labels is None else 'SEQ_CLS'
if args.train_type in ('lora', 'longlora'):
if args.use_swift_lora:
lora_config = LoRAConfig(lora_dtype=args.lora_dtype, **lora_kwargs)
model = Swift.prepare_model(model, lora_config)
logger.info(f'lora_config: {lora_config}')
elif args.tuner_backend == 'peft':
lora_config = LoraConfig(task_type='CAUSAL_LM', lora_dtype=args.lora_dtype, **lora_kwargs)
lora_config = LoraConfig(task_type=task_type, lora_dtype=args.lora_dtype, **lora_kwargs)
if args.init_weights == 'lora-ga':
try:
import lora_ga
Expand Down Expand Up @@ -211,7 +211,7 @@ def prepare_adapter(args: TrainArguments, model, *, template=None, train_dataset
lora_kwargs.pop('lorap_lr_ratio', None)
lora_kwargs['rank_pattern'] = None
adalora_config = AdaLoraConfig(
task_type='CAUSAL_LM',
task_type=task_type,
**lora_kwargs,
target_r=args.adalora_target_r,
init_r=args.adalora_init_r,
Expand Down
16 changes: 11 additions & 5 deletions swift/trainers/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import shutil
import time
from contextlib import contextmanager
from functools import wraps
from copy import copy
from functools import wraps
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -250,20 +250,26 @@ def _save_checkpoint(self, *args, **kwargs):

@contextmanager
def _patch_loss_function(self):
if not hasattr(self.model, 'loss_function'):
model = self.model
if isinstance(model, PeftModel):
model = model.model
model_cls = model.__class__
if not hasattr(model_cls, 'loss_function'):
yield
return

loss_function = self.model.loss_function
loss_function = model.loss_function
_old_loss_function = model_cls.loss_function

@staticmethod
@wraps(loss_function)
def new_loss_function(logits, labels, **kwargs):
labels = labels.to(logits.device) # fix device_map
return loss_function(logits=logits, labels=labels, **kwargs)

self.model.loss_function = new_loss_function
model_cls.loss_function = new_loss_function
yield
self.model.loss_function = loss_function
model_cls.loss_function = _old_loss_function

def train(self, *args, **kwargs):
if self.model.model_meta.is_multimodal:
Expand Down
2 changes: 1 addition & 1 deletion swift/trainers/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def compute_loss(self, model, inputs, return_outputs=None, num_items_in_batch=No
else:
unwrapped_model = self.accelerator.unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel):
model_name = unwrapped_model.base_model.model._get_name()
model_name = unwrapped_model.model._get_name()
else:
model_name = unwrapped_model._get_name()
# User-defined compute_loss function
Expand Down

0 comments on commit 97e39da

Please sign in to comment.