diff --git "a/docs/source/Customization/\346\217\222\344\273\266\345\214\226.md" "b/docs/source/Customization/\346\217\222\344\273\266\345\214\226.md" index fc86191d2..8cb3542cb 100644 --- "a/docs/source/Customization/\346\217\222\344\273\266\345\214\226.md" +++ "b/docs/source/Customization/\346\217\222\344\273\266\345\214\226.md" @@ -8,11 +8,6 @@ example在[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/ca callback会在trainer构造前注册进trainer中,example中给出了一个简单版本的EarlyStop方案。 -## 定制化trainer - -example在[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/custom_trainer.py). - -用户可以在这里继承现有trainer,并实现自己的训练逻辑,例如定制data_loader、定制compute_loss等。example中给出了一个text-classification任务的trainer。 ## 定制化loss diff --git a/docs/source/Instruction/ReleaseNote3.0.md b/docs/source/Instruction/ReleaseNote3.0.md index 4e9bbe743..feda24658 100644 --- a/docs/source/Instruction/ReleaseNote3.0.md +++ b/docs/source/Instruction/ReleaseNote3.0.md @@ -18,7 +18,6 @@ - 采用messages格式作为入参接口 4. 支持了plugin机制,用于定制训练过程,目前支持的plugin有: - callback 定制训练回调方法 - - custom_trainer 定制trainer - loss 定制loss方法 - loss_scale 定制每个token的权重 - metric 定制交叉验证的指标 diff --git a/docs/source_en/Customization/Pluginization.md b/docs/source_en/Customization/Pluginization.md index 095e2963b..335bf456e 100644 --- a/docs/source_en/Customization/Pluginization.md +++ b/docs/source_en/Customization/Pluginization.md @@ -8,12 +8,6 @@ Examples can be found [here](https://github.com/modelscope/swift/blob/main/swift Callbacks are registered into the trainer before constructing the trainer. The example provides a simple version of the EarlyStop scheme. -## Customized Trainer - -Examples can be found [here](https://github.com/modelscope/swift/blob/main/swift/plugin/custom_trainer.py). - -Users can inherit existing trainers and implement their own training logic here, such as customizing data loaders, customizing compute_loss, etc. The example demonstrates a trainer for a text-classification task. - ## Customized Loss Examples can be found [here](https://github.com/modelscope/swift/blob/main/swift/plugin/loss.py). diff --git a/docs/source_en/Instruction/ReleaseNote3.0.md b/docs/source_en/Instruction/ReleaseNote3.0.md index b3c63fd80..c6c1c9cec 100644 --- a/docs/source_en/Instruction/ReleaseNote3.0.md +++ b/docs/source_en/Instruction/ReleaseNote3.0.md @@ -21,7 +21,6 @@ 4. Supported plugin mechanism for customizing the training process. 
Current plugins include: - callback to customize training callbacks, - - custom_trainer to customize the trainer, - loss to customize the loss method, - loss_scale to customize the weight of each token, - metric to customize cross-validation metrics, diff --git a/swift/__init__.py b/swift/__init__.py index 31e6b0f04..111536dbc 100644 --- a/swift/__init__.py +++ b/swift/__init__.py @@ -14,7 +14,7 @@ SwiftTuners, LongLoRAConfig, LongLoRA, LongLoRAModelType, SCETuning, SCETuningConfig) from .trainers import (EvaluationStrategy, FSDPOption, HPSearchBackend, HubStrategy, IntervalStrategy, SchedulerType, ShardedDDPOption, TrainingArguments, Seq2SeqTrainingArguments, Trainer, - Seq2SeqTrainer) + Seq2SeqTrainer, SequenceClassificationTrainer) from .utils import get_logger else: _import_structure = { @@ -30,7 +30,8 @@ ], 'trainers': [ 'EvaluationStrategy', 'FSDPOption', 'HPSearchBackend', 'HubStrategy', 'IntervalStrategy', 'SchedulerType', - 'ShardedDDPOption', 'TrainingArguments', 'Seq2SeqTrainingArguments', 'Trainer', 'Seq2SeqTrainer' + 'ShardedDDPOption', 'TrainingArguments', 'Seq2SeqTrainingArguments', 'Trainer', 'Seq2SeqTrainer', + 'SequenceClassificationTrainer' ], 'utils': ['get_logger'] } diff --git a/swift/llm/dataset/dataset/llm.py b/swift/llm/dataset/dataset/llm.py index 80ea28749..adccafdef 100644 --- a/swift/llm/dataset/dataset/llm.py +++ b/swift/llm/dataset/dataset/llm.py @@ -358,7 +358,8 @@ class HC3Preprocessor(ResponsePreprocessor): def preprocess(self, row): rows = [] for response in ['Human', 'ChatGPT']: - query = self.prompt.format(question=row['query'], answer=row[f'{response.lower()}_answers']) + query = self.prompt.format( + question=row['query'], answer=self.random_state.choice(row[f'{response.lower()}_answers'])) rows.append(super().preprocess({'query': query, 'response': response})) return rows @@ -368,7 +369,8 @@ class HC3ClsPreprocessor(HC3Preprocessor): def preprocess(self, row): rows = [] for i, response in enumerate(['Human', 'ChatGPT']): - query = self.prompt.format(question=row['query'], answer=row[f'{response.lower()}_answers']) + query = self.prompt.format( + question=row['query'], answer=self.random_state.choice(row[f'{response.lower()}_answers'])) rows.append(ResponsePreprocessor.preprocess(self, {'query': query, 'label': i})) return rows diff --git a/swift/plugin/__init__.py b/swift/plugin/__init__.py index eb9e19bd0..f67289bfd 100644 --- a/swift/plugin/__init__.py +++ b/swift/plugin/__init__.py @@ -5,7 +5,6 @@ if TYPE_CHECKING: from .callback import extra_callbacks - from .custom_trainer import custom_trainer_class from .loss import LOSS_MAPPING, get_loss_func from .loss_scale import loss_scale_map from .metric import InferStats, MeanMetric, Metric, compute_acc, get_metric @@ -17,7 +16,6 @@ _extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')} _import_structure = { 'callback': ['extra_callbacks'], - 'custom_trainer': ['custom_trainer_class'], 'loss': ['LOSS_MAPPING', 'get_loss_func'], 'loss_scale': ['loss_scale_map'], 'metric': ['InferStats', 'MeanMetric', 'Metric', 'compute_acc', 'get_metric'], diff --git a/swift/plugin/custom_trainer.py b/swift/plugin/custom_trainer.py deleted file mode 100644 index 45f225481..000000000 --- a/swift/plugin/custom_trainer.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) Alibaba, Inc. and its affiliates. 
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import torch
-
-from swift.trainers import Trainer
-
-
-class SequenceClassificationTrainer(Trainer):
-    """A trainer for text-classification task"""
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.label_names = ['labels']
-
-    def compute_loss(self, model, inputs, return_outputs=None, **kwargs):
-        if 'label' in inputs:
-            inputs['labels'] = torch.tensor(inputs.pop('label')).unsqueeze(1)
-        return super().compute_loss(model, inputs, return_outputs=return_outputs)
-
-    def prediction_step(
-        self,
-        model: torch.nn.Module,
-        inputs: Dict[str, Union[torch.Tensor, Any]],
-        prediction_loss_only: bool,
-        ignore_keys: Optional[List[str]] = None,
-        **gen_kwargs,
-    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
-        if 'label' in inputs:
-            inputs['labels'] = torch.tensor(inputs.pop('label')).unsqueeze(1)
-        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs)
-
-
-# To train sequence classification tasks, uncomment this.
-def custom_trainer_class(trainer_mapping, training_args_mapping):
-    # trainer_mapping['train'] = 'swift.plugin.custom_trainer.SequenceClassificationTrainer'
-    pass
diff --git a/swift/trainers/__init__.py b/swift/trainers/__init__.py
index 2e57e64de..f3126d881 100644
--- a/swift/trainers/__init__.py
+++ b/swift/trainers/__init__.py
@@ -20,7 +20,7 @@
     from .rlhf_trainer import (CPOTrainer, DPOTrainer, KTOTrainer, ORPOTrainer, RLHFTrainerMixin, PPOTrainer,
                                RewardTrainer)
     from .trainer_factory import TrainerFactory
-    from .trainers import Seq2SeqTrainer, Trainer
+    from .trainers import Seq2SeqTrainer, Trainer, SequenceClassificationTrainer
     from .mixin import SwiftMixin
 else:
 
@@ -33,7 +33,7 @@
         'rlhf_trainer': ['CPOTrainer', 'DPOTrainer', 'KTOTrainer', 'ORPOTrainer', 'RLHFTrainerMixin', 'PPOTrainer',
                          'RewardTrainer'],
         'trainer_factory': ['TrainerFactory'],
-        'trainers': ['Seq2SeqTrainer', 'Trainer'],
+        'trainers': ['Seq2SeqTrainer', 'Trainer', 'SequenceClassificationTrainer'],
         'mixin': ['SwiftMixin'],
     }
 
diff --git a/swift/trainers/trainer_factory.py b/swift/trainers/trainer_factory.py
index 795f56d61..3cae2f97b 100644
--- a/swift/trainers/trainer_factory.py
+++ b/swift/trainers/trainer_factory.py
@@ -4,7 +4,6 @@
 from dataclasses import asdict
 from typing import Dict
 
-from swift.plugin import custom_trainer_class
 from swift.utils import get_logger
 
 logger = get_logger()
@@ -12,7 +11,8 @@
 
 class TrainerFactory:
     TRAINER_MAPPING = {
-        'train': 'swift.trainers.Seq2SeqTrainer',
+        'seq2seq': 'swift.trainers.Seq2SeqTrainer',
+        'seq_cls': 'swift.trainers.SequenceClassificationTrainer',
         'dpo': 'swift.trainers.DPOTrainer',
         'orpo': 'swift.trainers.ORPOTrainer',
         'kto': 'swift.trainers.KTOTrainer',
@@ -22,7 +22,8 @@ class TrainerFactory:
     }
 
     TRAINING_ARGS_MAPPING = {
-        'train': 'swift.trainers.Seq2SeqTrainingArguments',
+        'seq2seq': 'swift.trainers.Seq2SeqTrainingArguments',
+        'seq_cls': 'swift.trainers.TrainingArguments',
         'dpo': 'swift.trainers.DPOConfig',
         'orpo': 'swift.trainers.ORPOConfig',
         'kto': 'swift.trainers.KTOConfig',
@@ -31,14 +32,14 @@ class TrainerFactory:
         'ppo': 'swift.trainers.PPOConfig',
     }
 
-    custom_trainer_class(TRAINER_MAPPING, TRAINING_ARGS_MAPPING)
-
     @staticmethod
     def get_cls(args, mapping: Dict[str, str]):
         if hasattr(args, 'rlhf_type'):
             train_method = args.rlhf_type
+        elif args.num_labels is None:
+            train_method = 'seq2seq'
         else:
-            train_method = 'train'
+            train_method = 'seq_cls'
         module_path, class_name = mapping[train_method].rsplit('.', 1)
         module = importlib.import_module(module_path)
         return getattr(module, class_name)
diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py
index 84ed05328..2e261b201 100644
--- a/swift/trainers/trainers.py
+++ b/swift/trainers/trainers.py
@@ -165,3 +165,28 @@ def _compute_token_acc(self, outputs, labels) -> None:
             if k not in self._custom_metrics:
                 self._custom_metrics[k] = MeanMetric(nan_value=None)
             self._custom_metrics[k].update(v)
+
+
+class SequenceClassificationTrainer(Trainer):
+    """A trainer for text-classification task"""
+
+    # def __init__(self, *args, **kwargs):
+    #     super().__init__(*args, **kwargs)
+    #     self.label_names = ['labels']
+
+    def compute_loss(self, model, inputs, return_outputs=None, **kwargs):
+        if 'label' in inputs:
+            inputs['labels'] = torch.tensor(inputs.pop('label')).unsqueeze(1)
+        return super().compute_loss(model, inputs, return_outputs=return_outputs)
+
+    def prediction_step(
+        self,
+        model: torch.nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+        **gen_kwargs,
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        if 'label' in inputs:
+            inputs['labels'] = torch.tensor(inputs.pop('label')).unsqueeze(1)
+        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs)
diff --git a/tests/train/test_cls.py b/tests/train/test_cls.py
new file mode 100644
index 000000000..4cbd1d89d
--- /dev/null
+++ b/tests/train/test_cls.py
@@ -0,0 +1,8 @@
+
+from swift.llm import sft_main, TrainArguments
+
+
+if __name__ == '__main__':
+    sft_main(TrainArguments(model='Qwen/Qwen2.5-7B-Instruct',
+                            num_labels=2,
+                            dataset='simpleai/HC3-Chinese:baike_cls#1000'))
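
Note on the behavioral change in swift/trainers/trainer_factory.py above: with the custom_trainer plugin gone, TrainerFactory.get_cls now selects the trainer purely from the arguments — args.rlhf_type wins when present, otherwise a set num_labels routes to 'seq_cls' (SequenceClassificationTrainer) and an unset one routes to 'seq2seq' (Seq2SeqTrainer). A minimal sketch of that dispatch, assuming this patch is applied and ms-swift is importable; the Args dataclass below is a hypothetical stand-in for TrainArguments that models only the attributes get_cls inspects:

    # Sketch: how TrainerFactory routes between 'seq2seq' and 'seq_cls'.
    from dataclasses import dataclass
    from typing import Optional

    from swift.trainers.trainer_factory import TrainerFactory

    @dataclass
    class Args:  # hypothetical stand-in for TrainArguments
        num_labels: Optional[int] = None  # deliberately no `rlhf_type` attribute

    # num_labels unset -> 'seq2seq' -> swift.trainers.Seq2SeqTrainer
    print(TrainerFactory.get_cls(Args(), TrainerFactory.TRAINER_MAPPING).__name__)
    # num_labels set -> 'seq_cls' -> swift.trainers.SequenceClassificationTrainer
    print(TrainerFactory.get_cls(Args(num_labels=2), TrainerFactory.TRAINER_MAPPING).__name__)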
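End to end, tests/train/test_cls.py exercises the same switch: passing num_labels=2 to TrainArguments is the only change needed for sft_main to train Qwen/Qwen2.5-7B-Instruct as a two-class classifier on the HC3-Chinese baike_cls subset. The companion change in swift/llm/dataset/dataset/llm.py makes the HC3 preprocessors sample one answer per question via self.random_state.choice, rather than formatting the whole answer list into the prompt.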