Commit: update

Jintao-Huang committed Dec 24, 2024
1 parent dfb18cf commit 288df5b
Showing 11 changed files with 88 additions and 54 deletions.
16 changes: 12 additions & 4 deletions swift/__init__.py
@@ -14,7 +14,7 @@
SwiftTuners, LongLoRAConfig, LongLoRA, LongLoRAModelType, SCETuning, SCETuningConfig)
from .trainers import (EvaluationStrategy, FSDPOption, HPSearchBackend, HubStrategy, IntervalStrategy,
SchedulerType, ShardedDDPOption, TrainingArguments, Seq2SeqTrainingArguments, Trainer,
Seq2SeqTrainer, SequenceClassificationTrainer)
Seq2SeqTrainer)
from .utils import get_logger
else:
_import_structure = {
@@ -29,9 +29,17 @@
'Swift', 'SwiftTuners', 'LongLoRAConfig', 'LongLoRA', 'LongLoRAModelType', 'SCETuning', 'SCETuningConfig'
],
'trainers': [
'EvaluationStrategy', 'FSDPOption', 'HPSearchBackend', 'HubStrategy', 'IntervalStrategy', 'SchedulerType',
'ShardedDDPOption', 'TrainingArguments', 'Seq2SeqTrainingArguments', 'Trainer', 'Seq2SeqTrainer',
'SequenceClassificationTrainer'
'EvaluationStrategy',
'FSDPOption',
'HPSearchBackend',
'HubStrategy',
'IntervalStrategy',
'SchedulerType',
'ShardedDDPOption',
'TrainingArguments',
'Seq2SeqTrainingArguments',
'Trainer',
'Seq2SeqTrainer',
],
'utils': ['get_logger']
}
2 changes: 1 addition & 1 deletion swift/llm/argument/base_args/base_args.py
@@ -238,7 +238,7 @@ def _init_device(self):

def get_template(self, processor: 'Processor') -> 'Template':
template_kwargs = self.get_template_kwargs()
template = get_template(self.template, processor, use_chat_template=self.use_chat_template, **template_kwargs)
template = get_template(self.template, processor, **template_kwargs)
logger.info(f'default_system: {template.template_meta.default_system}')
return template

3 changes: 2 additions & 1 deletion swift/llm/argument/base_args/template_args.py
@@ -60,5 +60,6 @@ def get_template_kwargs(self):
'tools_prompt': self.tools_prompt,
'loss_scale': self.loss_scale,
'sequence_parallel_size': self.sequence_parallel_size,
'template_backend': self.template_backend
'template_backend': self.template_backend,
'use_chat_template': self.use_chat_template
}
18 changes: 17 additions & 1 deletion swift/llm/dataset/dataset/llm.py
@@ -398,11 +398,27 @@ def preprocess(self, row):
subsets=hc3_subsets,
tags=['text-generation', 'classification', '🔥']))

hc3_subset_names = ['finance', 'medicine']
hc3_subsets: List[SubsetDataset] = []
for hc3_subset_name in hc3_subset_names:
hc3_subsets.append(
SubsetDataset(
name=hc3_subset_name,
subset=hc3_subset_name,
preprocess_func=HC3Preprocessor(),
))
hc3_subsets.append(
SubsetDataset(
name=f'{hc3_subset_name}_cls',
subset=hc3_subset_name,
preprocess_func=HC3ClsPreprocessor(),
))

register_dataset(
DatasetMeta(
ms_dataset_id='simpleai/HC3',
hf_dataset_id='Hello-SimpleAI/HC3',
subsets=['finance', 'medicine'],
subsets=hc3_subsets,
preprocess_func=HC3Preprocessor(),
tags=['text-generation', 'classification', '🔥']))

23 changes: 15 additions & 8 deletions swift/llm/template/base.py
@@ -103,7 +103,7 @@ def __init__(

self.mode: Literal['pt', 'vllm', 'lmdeploy', # infer
'train', 'rlhf', 'kto' # train
] = 'pt'
'seq_cls'] = 'pt'
self._handles = []
self._deepspeed_initialize = None

@@ -201,7 +201,7 @@ def encode(self,
encoded = Template._encode(self, inputs)
for key in ['images', 'audios', 'videos']:
encoded[key] = getattr(inputs, key)
elif self.mode in {'pt', 'train'}:
elif self.mode in {'pt', 'train', 'seq_cls'}:
encoded = self._encode(inputs)
elif self.mode == 'rlhf':
encoded = self._rlhf_encode(inputs)
@@ -696,7 +696,7 @@ def pre_forward_hook(self, model: nn.Module, args, kwargs):
def is_training(self):
return self.mode not in {'vllm', 'lmdeploy', 'pt'}

def set_mode(self, mode: Literal['vllm', 'lmdeploy', 'pt', 'train', 'rlhf', 'kto']) -> None:
def set_mode(self, mode: Literal['vllm', 'lmdeploy', 'pt', 'seq_cls', 'train', 'rlhf', 'kto']) -> None:
self.mode = mode

def register_post_encode_hook(self, models: List[nn.Module]) -> None:
@@ -741,6 +741,8 @@ def data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int
return self._kto_data_collator(batch, padding_to=padding_to)
elif self.mode in {'pt', 'train'}:
return self._data_collator(batch, padding_to=padding_to)
elif self.mode == 'seq_cls':
return self._seq_cls_data_collator(batch, padding_to=padding_to)

@staticmethod
def _fetch_inputs_startswith(batch: List[Dict[str, Any]], prefix: str) -> List[Dict[str, Any]]:
@@ -797,6 +799,16 @@ def _kto_data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optiona
res['label'] = label
return res

def _seq_cls_data_collator(self,
batch: List[Dict[str, Any]],
*,
padding_to: Optional[int] = None) -> Dict[str, Any]:
labels = [b['label'] for b in batch if b.get('label') is not None]
res = self._data_collator(batch, padding_to=padding_to)
if labels:
res['labels'] = torch.tensor(labels)
return res

def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[int] = None) -> Dict[str, Any]:
"""
Args:
@@ -862,11 +874,6 @@ def _data_collator(self, batch: List[Dict[str, Any]], *, padding_to: Optional[in
pixel_values_videos = [b['pixel_values_videos'] for b in batch if b.get('pixel_values_videos') is not None]
if len(pixel_values_videos) > 0:
res['pixel_values_videos'] = torch.concat(pixel_values_videos)
# sequence_classification
if self.is_training:
label = [b['label'] for b in batch if b.get('label') is not None]
if label:
res['label'] = label
if use_torchacc() or self.sequence_parallel_size > 1:
res = self._torchacc_xtuner_data_collator(res, padding_to, self.tokenizer, padding_side)
return res
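The new seq_cls branch only adds classification targets on top of the regular collation. Below is a minimal, self-contained sketch of that behavior; seq_cls_collate and pad_token_id are illustrative names rather than swift APIs, and the real path delegates padding to Template._data_collator after template.set_mode('seq_cls').

# Minimal sketch of the seq_cls collation added above: pad input_ids as usual,
# then stack the per-sample integer `label` fields into a single `labels` tensor.
# `seq_cls_collate` and `pad_token_id` are illustrative names, not swift APIs.
from typing import Any, Dict, List

import torch


def seq_cls_collate(batch: List[Dict[str, Any]], pad_token_id: int = 0) -> Dict[str, Any]:
    labels = [b['label'] for b in batch if b.get('label') is not None]
    max_len = max(len(b['input_ids']) for b in batch)
    input_ids = [b['input_ids'] + [pad_token_id] * (max_len - len(b['input_ids'])) for b in batch]
    res = {'input_ids': torch.tensor(input_ids)}
    if labels:
        res['labels'] = torch.tensor(labels)  # what `_seq_cls_data_collator` contributes
    return res


if __name__ == '__main__':
    batch = [{'input_ids': [1, 2, 3], 'label': 0}, {'input_ids': [4, 5], 'label': 1}]
    print(seq_cls_collate(batch))  # input_ids padded to [[1, 2, 3], [4, 5, 0]], labels tensor([0, 1])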
12 changes: 7 additions & 5 deletions swift/llm/train/sft.py
@@ -1,7 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from functools import partial
from typing import List, Union
from typing import List, Optional, Union

from datasets import Dataset as HfDataset

@@ -30,7 +30,7 @@ def __init__(self, args: Union[List[str], TrainArguments, None] = None) -> None:
self.args.save_args()
self.train_msg = {}
self._prepare_model_tokenizer()
self._prepare_template(True)
self._prepare_template()
self._prepare_callbacks()

def _prepare_gradient_checkpointing(self):
@@ -71,14 +71,16 @@ def _prepare_model_tokenizer(self):
self._prepare_generation_config()
self._prepare_gradient_checkpointing()

def _prepare_template(self, use_chat_template: bool) -> None:
def _prepare_template(self, use_chat_template: Optional[bool] = None) -> None:
args = self.args
template_kwargs = args.get_template_kwargs()
template = get_template(args.template, self.processor, use_chat_template=use_chat_template, **template_kwargs)
if use_chat_template is not None:
template_kwargs['use_chat_template'] = use_chat_template
template = get_template(args.template, self.processor, **template_kwargs)
logger.info(f'default_system: {template.template_meta.default_system}')
if template.use_model:
template.model = self.model
template.set_mode('train')
template.set_mode('train' if args.num_labels is None else 'seq_cls')
self.template = template

def _get_dataset(self):
4 changes: 2 additions & 2 deletions swift/trainers/__init__.py
@@ -20,7 +20,7 @@
from .rlhf_trainer import (CPOTrainer, DPOTrainer, KTOTrainer, ORPOTrainer, RLHFTrainerMixin, PPOTrainer,
RewardTrainer)
from .trainer_factory import TrainerFactory
from .trainers import Seq2SeqTrainer, Trainer, SequenceClassificationTrainer
from .trainers import Seq2SeqTrainer, Trainer
from .mixin import SwiftMixin

else:
@@ -33,7 +33,7 @@
'rlhf_trainer':
['CPOTrainer', 'DPOTrainer', 'KTOTrainer', 'ORPOTrainer', 'RLHFTrainerMixin', 'PPOTrainer', 'RewardTrainer'],
'trainer_factory': ['TrainerFactory'],
'trainers': ['Seq2SeqTrainer', 'Trainer', 'SequenceClassificationTrainer'],
'trainers': ['Seq2SeqTrainer', 'Trainer'],
'mixin': ['SwiftMixin'],
}

21 changes: 19 additions & 2 deletions swift/trainers/mixin.py
@@ -4,6 +4,7 @@
import os
import shutil
import time
from contextlib import contextmanager
from copy import copy
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -246,6 +247,23 @@ def _save_checkpoint(self, *args, **kwargs):
logger.info(f'Saving model checkpoint to {self.state.last_model_checkpoint}')
return result

@contextmanager
def _patch_loss_function(self):
if not hasattr(self.model, 'loss_function'):
yield
return

loss_function = self.model.loss_function

@wraps(loss_function)  # functools.wraps, so the patched function keeps the original's metadata
def new_loss_function(logits, labels, **kwargs):
labels = labels.to(logits.device) # fix device_map
return loss_function(logits=logits, labels=labels, **kwargs)

self.model.loss_function = new_loss_function
yield
self.model.loss_function = loss_function

def train(self, *args, **kwargs):
if self.model.model_meta.is_multimodal:
models = list(
@@ -255,9 +273,8 @@ def train(self, *args, **kwargs):
]))
self.template.register_post_encode_hook(models)
logger.info(f'Successfully registered post_encode hook: {[model.__class__.__name__ for model in models]}')
self.model_accepts_loss_kwargs = True # fix transformers>=4.46.2
self._save_initial_model(self.args.output_dir)
with self.hub.patch_hub():
with self.hub.patch_hub(), self._patch_loss_function():
return super().train(*args, **kwargs)
self.template.remove_post_encode_hook()

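The intent of _patch_loss_function can be seen in isolation below: wrap an existing loss callable so the labels are moved to the logits' device before the original loss runs, which matters when device_map places the head and the labels on different GPUs. This is a standalone sketch with a stand-in loss (base_loss), not code from the repository.

# Standalone sketch of the loss patch: `base_loss` is a stand-in for
# `model.loss_function`; the wrapper only normalizes the labels' device.
from functools import wraps

import torch
import torch.nn.functional as F


def base_loss(logits: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
    return F.cross_entropy(logits, labels)


@wraps(base_loss)
def patched_loss(logits: torch.Tensor, labels: torch.Tensor, **kwargs) -> torch.Tensor:
    labels = labels.to(logits.device)  # fix device_map: loss inputs must share a device
    return base_loss(logits=logits, labels=labels, **kwargs)


if __name__ == '__main__':
    logits = torch.randn(4, 2)
    labels = torch.tensor([0, 1, 1, 0])
    print(patched_loss(logits=logits, labels=labels))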
2 changes: 1 addition & 1 deletion swift/trainers/trainer_factory.py
@@ -12,7 +12,7 @@
class TrainerFactory:
TRAINER_MAPPING = {
'seq2seq': 'swift.trainers.Seq2SeqTrainer',
'seq_cls': 'swift.trainers.SequenceClassificationTrainer',
'seq_cls': 'swift.trainers.Trainer',
'dpo': 'swift.trainers.DPOTrainer',
'orpo': 'swift.trainers.ORPOTrainer',
'kto': 'swift.trainers.KTOTrainer',
26 changes: 1 addition & 25 deletions swift/trainers/trainers.py
@@ -30,6 +30,7 @@ class Seq2SeqTrainer(TorchAccMixin, SwiftMixin, HfSeq2SeqTrainer):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_accepts_loss_kwargs = True # fix transformers>=4.46.2
if self.args.predict_with_generate:
from swift.llm import PtEngine
self.infer_engine = PtEngine.from_model_template(
@@ -165,28 +166,3 @@ def _compute_token_acc(self, outputs, labels) -> None:
if k not in self._custom_metrics:
self._custom_metrics[k] = MeanMetric(nan_value=None)
self._custom_metrics[k].update(v)


class SequenceClassificationTrainer(Trainer):
"""A trainer for text-classification task"""

# def __init__(self, *args, **kwargs):
# super().__init__(*args, **kwargs)
# self.label_names = ['labels']

def compute_loss(self, model, inputs, return_outputs=None, **kwargs):
if 'label' in inputs:
inputs['labels'] = torch.tensor(inputs.pop('label')).unsqueeze(1)
return super().compute_loss(model, inputs, return_outputs=return_outputs)

def prediction_step(
self,
model: torch.nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
**gen_kwargs,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
if 'label' in inputs:
inputs['labels'] = torch.tensor(inputs.pop('label')).unsqueeze(1)
return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs)
15 changes: 11 additions & 4 deletions tests/train/test_cls.py
@@ -1,8 +1,15 @@
from swift.llm import TrainArguments, sft_main

from swift.llm import sft_main, TrainArguments

def test_llm():
sft_main(
TrainArguments(
model='Qwen/Qwen2.5-7B-Instruct',
num_labels=2,
# dataset=['simpleai/HC3-Chinese:baike_cls#1000', ],
dataset=['simpleai/HC3:finance_cls#1000'],
use_chat_template=False))


if __name__ == '__main__':
sft_main(TrainArguments(model='Qwen/Qwen2.5-7B-Instruct',
num_labels=2,
dataset='simpleai/HC3-Chinese:baike_cls#1000'))
test_llm()
