Skip to content

Commit

Permalink
support seq_cls
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Dec 24, 2024
1 parent 00c2eaa commit dfb18cf
Show file tree
Hide file tree
Showing 12 changed files with 49 additions and 64 deletions.
5 changes: 0 additions & 5 deletions docs/source/Customization/插件化.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@ example在[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/ca

callback会在trainer构造前注册进trainer中,example中给出了一个简单版本的EarlyStop方案。

## 定制化trainer

example在[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/custom_trainer.py).

用户可以在这里继承现有trainer,并实现自己的训练逻辑,例如定制data_loader、定制compute_loss等。example中给出了一个text-classification任务的trainer。

## 定制化loss

Expand Down
1 change: 0 additions & 1 deletion docs/source/Instruction/ReleaseNote3.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
- 采用messages格式作为入参接口
4. 支持了plugin机制,用于定制训练过程,目前支持的plugin有:
- callback 定制训练回调方法
- custom_trainer 定制trainer
- loss 定制loss方法
- loss_scale 定制每个token的权重
- metric 定制交叉验证的指标
Expand Down
6 changes: 0 additions & 6 deletions docs/source_en/Customization/Pluginization.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,6 @@ Examples can be found [here](https://github.com/modelscope/swift/blob/main/swift

Callbacks are registered with the trainer before the trainer is constructed. The example provides a simple implementation of an EarlyStop scheme.

## Customized Trainer

Examples can be found [here](https://github.com/modelscope/swift/blob/main/swift/plugin/custom_trainer.py).

Users can inherit from existing trainers and implement their own training logic here, for example by customizing the data loader or overriding compute_loss. The example demonstrates a trainer for a text-classification task.

## Customized Loss

Examples can be found [here](https://github.com/modelscope/swift/blob/main/swift/plugin/loss.py).
Expand Down
1 change: 0 additions & 1 deletion docs/source_en/Instruction/ReleaseNote3.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

4. Supported plugin mechanism for customizing the training process. Current plugins include:
- callback to customize training callbacks,
- custom_trainer to customize the trainer,
- loss to customize the loss method,
- loss_scale to customize the weight of each token,
- metric to customize cross-validation metrics,
Expand Down
5 changes: 3 additions & 2 deletions swift/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
SwiftTuners, LongLoRAConfig, LongLoRA, LongLoRAModelType, SCETuning, SCETuningConfig)
from .trainers import (EvaluationStrategy, FSDPOption, HPSearchBackend, HubStrategy, IntervalStrategy,
SchedulerType, ShardedDDPOption, TrainingArguments, Seq2SeqTrainingArguments, Trainer,
Seq2SeqTrainer)
Seq2SeqTrainer, SequenceClassificationTrainer)
from .utils import get_logger
else:
_import_structure = {
Expand All @@ -30,7 +30,8 @@
],
'trainers': [
'EvaluationStrategy', 'FSDPOption', 'HPSearchBackend', 'HubStrategy', 'IntervalStrategy', 'SchedulerType',
'ShardedDDPOption', 'TrainingArguments', 'Seq2SeqTrainingArguments', 'Trainer', 'Seq2SeqTrainer'
'ShardedDDPOption', 'TrainingArguments', 'Seq2SeqTrainingArguments', 'Trainer', 'Seq2SeqTrainer',
'SequenceClassificationTrainer'
],
'utils': ['get_logger']
}
Expand Down
6 changes: 4 additions & 2 deletions swift/llm/dataset/dataset/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,8 @@ class HC3Preprocessor(ResponsePreprocessor):
def preprocess(self, row):
    """Expand one HC3 row into two examples, one per answer source.

    For each of 'Human' and 'ChatGPT', a single answer is drawn with
    ``self.random_state.choice`` from the row's ``human_answers`` /
    ``chatgpt_answers`` field and formatted into ``self.prompt`` together with
    the row's ``query``. The source name itself is used as the response.

    Returns:
        A list with the two rows produced by the parent ``preprocess``.
    """
    rows = []
    for response in ['Human', 'ChatGPT']:
        # Pick one answer instead of formatting the whole answers list into the prompt.
        query = self.prompt.format(
            question=row['query'], answer=self.random_state.choice(row[f'{response.lower()}_answers']))
        rows.append(super().preprocess({'query': query, 'response': response}))
    return rows

Expand All @@ -368,7 +369,8 @@ class HC3ClsPreprocessor(HC3Preprocessor):
def preprocess(self, row):
    """Expand one HC3 row into two labelled classification examples.

    The answer source is encoded as an integer label by position in
    ``['Human', 'ChatGPT']`` (0 and 1 respectively). For each source a single
    answer is drawn with ``self.random_state.choice`` from the row's
    ``human_answers`` / ``chatgpt_answers`` field and formatted into
    ``self.prompt`` together with the row's ``query``.

    Returns:
        A list with the two rows produced by ``ResponsePreprocessor.preprocess``.
    """
    rows = []
    for i, response in enumerate(['Human', 'ChatGPT']):
        # Pick one answer instead of formatting the whole answers list into the prompt.
        query = self.prompt.format(
            question=row['query'], answer=self.random_state.choice(row[f'{response.lower()}_answers']))
        # Call ResponsePreprocessor directly so the parent HC3Preprocessor.preprocess is bypassed.
        rows.append(ResponsePreprocessor.preprocess(self, {'query': query, 'label': i}))
    return rows

Expand Down
2 changes: 0 additions & 2 deletions swift/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

if TYPE_CHECKING:
from .callback import extra_callbacks
from .custom_trainer import custom_trainer_class
from .loss import LOSS_MAPPING, get_loss_func
from .loss_scale import loss_scale_map
from .metric import InferStats, MeanMetric, Metric, compute_acc, get_metric
Expand All @@ -17,7 +16,6 @@
_extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')}
_import_structure = {
'callback': ['extra_callbacks'],
'custom_trainer': ['custom_trainer_class'],
'loss': ['LOSS_MAPPING', 'get_loss_func'],
'loss_scale': ['loss_scale_map'],
'metric': ['InferStats', 'MeanMetric', 'Metric', 'compute_acc', 'get_metric'],
Expand Down
37 changes: 0 additions & 37 deletions swift/plugin/custom_trainer.py

This file was deleted.

4 changes: 2 additions & 2 deletions swift/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .rlhf_trainer import (CPOTrainer, DPOTrainer, KTOTrainer, ORPOTrainer, RLHFTrainerMixin, PPOTrainer,
RewardTrainer)
from .trainer_factory import TrainerFactory
from .trainers import Seq2SeqTrainer, Trainer
from .trainers import Seq2SeqTrainer, Trainer, SequenceClassificationTrainer
from .mixin import SwiftMixin

else:
Expand All @@ -33,7 +33,7 @@
'rlhf_trainer':
['CPOTrainer', 'DPOTrainer', 'KTOTrainer', 'ORPOTrainer', 'RLHFTrainerMixin', 'PPOTrainer', 'RewardTrainer'],
'trainer_factory': ['TrainerFactory'],
'trainers': ['Seq2SeqTrainer', 'Trainer'],
'trainers': ['Seq2SeqTrainer', 'Trainer', 'SequenceClassificationTrainer'],
'mixin': ['SwiftMixin'],
}

Expand Down
13 changes: 7 additions & 6 deletions swift/trainers/trainer_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
from dataclasses import asdict
from typing import Dict

from swift.plugin import custom_trainer_class
from swift.utils import get_logger

logger = get_logger()


class TrainerFactory:
TRAINER_MAPPING = {
'train': 'swift.trainers.Seq2SeqTrainer',
'seq2seq': 'swift.trainers.Seq2SeqTrainer',
'seq_cls': 'swift.trainers.SequenceClassificationTrainer',
'dpo': 'swift.trainers.DPOTrainer',
'orpo': 'swift.trainers.ORPOTrainer',
'kto': 'swift.trainers.KTOTrainer',
Expand All @@ -22,7 +22,8 @@ class TrainerFactory:
}

TRAINING_ARGS_MAPPING = {
'train': 'swift.trainers.Seq2SeqTrainingArguments',
'seq2seq': 'swift.trainers.Seq2SeqTrainingArguments',
'seq_cls': 'swift.trainers.TrainingArguments',
'dpo': 'swift.trainers.DPOConfig',
'orpo': 'swift.trainers.ORPOConfig',
'kto': 'swift.trainers.KTOConfig',
Expand All @@ -31,14 +32,14 @@ class TrainerFactory:
'ppo': 'swift.trainers.PPOConfig',
}

custom_trainer_class(TRAINER_MAPPING, TRAINING_ARGS_MAPPING)

@staticmethod
def get_cls(args, mapping: Dict[str, str]):
    """Resolve the class registered in ``mapping`` for the training method implied by ``args``.

    The method key is ``args.rlhf_type`` when that attribute exists. Otherwise it
    is ``'seq2seq'`` when ``args.num_labels`` is None and ``'seq_cls'`` when it is set.

    Args:
        args: Object exposing ``rlhf_type`` or, failing that, ``num_labels``.
        mapping: Training-method key -> dotted ``package.module.ClassName`` path.

    Returns:
        The attribute named by the last dotted component, read from the imported module.

    Raises:
        KeyError: If the selected method key is missing from ``mapping``.
    """
    if hasattr(args, 'rlhf_type'):
        train_method = args.rlhf_type
    elif args.num_labels is None:
        train_method = 'seq2seq'
    else:
        train_method = 'seq_cls'
    # Split "pkg.module.ClassName" into the module to import and the attribute to fetch.
    module_path, class_name = mapping[train_method].rsplit('.', 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
Expand Down
25 changes: 25 additions & 0 deletions swift/trainers/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,28 @@ def _compute_token_acc(self, outputs, labels) -> None:
if k not in self._custom_metrics:
self._custom_metrics[k] = MeanMetric(nan_value=None)
self._custom_metrics[k].update(v)


class SequenceClassificationTrainer(Trainer):
    """Trainer for text-classification tasks.

    Batches may carry their targets under the key ``'label'``. Both overridden
    hooks move that entry to ``'labels'`` (as a tensor with a size-1 dimension
    inserted at index 1) before delegating to the parent ``Trainer``.
    """

    @staticmethod
    def _prepare_cls_labels(inputs):
        """Move ``inputs['label']`` to ``inputs['labels']`` in place and return ``inputs``.

        Inputs without a ``'label'`` key are returned untouched.
        """
        if 'label' in inputs:
            inputs['labels'] = torch.tensor(inputs.pop('label')).unsqueeze(1)
        return inputs

    def compute_loss(self, model, inputs, return_outputs=None, **kwargs):
        """Normalize the label key, then defer to the parent loss computation.

        Extra keyword arguments are accepted for signature compatibility but
        are not forwarded to the parent implementation.
        """
        inputs = self._prepare_cls_labels(inputs)
        return super().compute_loss(model, inputs, return_outputs=return_outputs)

    def prediction_step(
        self,
        model: torch.nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
        **gen_kwargs,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Normalize the label key, then defer to the parent prediction step."""
        inputs = self._prepare_cls_labels(inputs)
        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs)
8 changes: 8 additions & 0 deletions tests/train/test_cls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

from swift.llm import sft_main, TrainArguments


if __name__ == '__main__':
    # Build the training arguments first; a non-None num_labels requests the classification setup.
    train_args = TrainArguments(
        model='Qwen/Qwen2.5-7B-Instruct',
        num_labels=2,
        dataset='simpleai/HC3-Chinese:baike_cls#1000',
    )
    sft_main(train_args)

0 comments on commit dfb18cf

Please sign in to comment.