From 4aa239a84ef88bfc52211000bba64931fea8daf0 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Tue, 24 Dec 2024 15:42:35 +0800
Subject: [PATCH] support cls

---
 swift/llm/argument/base_args/base_args.py | 13 +++++++-
 swift/llm/argument/train_args.py          |  1 -
 swift/llm/model/register.py               |  5 +--
 swift/llm/template/base.py                |  2 ++
 swift/llm/train/sft.py                    |  2 +-
 swift/llm/train/tuner.py                  |  2 +-
 swift/plugin/metric.py                    | 14 +++-----
 swift/trainers/mixin.py                   | 20 +++++++++++-
 swift/trainers/rlhf_trainer/rlhf_mixin.py |  2 +-
 swift/trainers/trainer_factory.py         |  8 ++---
 swift/trainers/trainers.py                | 40 +++++++++--------------
 11 files changed, 62 insertions(+), 47 deletions(-)

diff --git a/swift/llm/argument/base_args/base_args.py b/swift/llm/argument/base_args/base_args.py
index 529053cad..d1039f997 100644
--- a/swift/llm/argument/base_args/base_args.py
+++ b/swift/llm/argument/base_args/base_args.py
@@ -77,10 +77,12 @@ class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, Dat
         load_data_args (bool): Flag to determine if dataset configuration should be loaded. Default is False.
         use_hf (bool): Flag to determine if Hugging Face should be used. Default is False.
         hub_token (Optional[str]): SDK token for authentication. Default is None.
+        num_labels (Optional[int]): Number of labels for classification tasks. Default is None.
         custom_register_path (List[str]): Path to custom .py file for dataset registration. Default is None.
         ignore_args_error (bool): Flag to ignore argument errors for notebook compatibility. Default is False.
         use_swift_lora (bool): Use swift lora, a compatible argument
     """
+    task_type: Literal['causal_lm', 'seq_cls'] = None
     tuner_backend: Literal['peft', 'unsloth'] = 'peft'
     train_type: str = field(default='lora', metadata={'help': f'train_type choices: {list(get_supported_tuners())}'})
     adapters: List[str] = field(default_factory=list)
@@ -112,6 +114,14 @@ def _init_custom_register(self) -> None:
             __import__(fname.rstrip('.py'))
         logger.info(f'Successfully registered `{self.custom_register_path}`')

+    def _init_task_type(self):
+        if self.task_type is not None:
+            return
+        if self.num_labels is None:
+            self.task_type = 'causal_lm'
+        else:
+            self.task_type = 'seq_cls'
+
     def _init_adapters(self):
         if isinstance(self.adapters, str):
             self.adapters = [self.adapters]
@@ -124,6 +134,7 @@ def __post_init__(self):
             self.use_hf = True
             os.environ['USE_HF'] = '1'
         CompatArguments.__post_init__(self)
+        self._init_task_type()
         self._init_adapters()
         self._init_ckpt_dir()
         self._init_custom_register()
@@ -252,7 +263,7 @@ def get_model_processor(self, *, model=None, model_type=None, model_revision=Non
         kwargs['model_revision'] = model_revision or self.model_revision

         model_kwargs = {}
-        if self.num_labels is not None:
+        if self.task_type == 'seq_cls':
             from transformers import AutoModelForSequenceClassification
             kwargs['automodel_class'] = AutoModelForSequenceClassification
             model_kwargs = {'num_labels': self.num_labels}
diff --git a/swift/llm/argument/train_args.py b/swift/llm/argument/train_args.py
index 165abe231..1d2566930 100644
--- a/swift/llm/argument/train_args.py
+++ b/swift/llm/argument/train_args.py
@@ -97,7 +97,6 @@ class TrainArguments(TorchAccArguments, TunerArguments, Seq2SeqTrainingOverrideA
         resume_only_model (bool): Flag to resume training only the model. Default is False.
         check_model (bool): Flag to check the model is latest. Default is True.
         loss_type (Optional[str]): Type of loss function to use. Default is None.
-        num_labels (Optional[int]): Number of labels for classification tasks. Default is None.
         packing (bool): Flag to enable packing of datasets. Default is False.
         lazy_tokenize (Optional[bool]): Flag to enable lazy tokenization. Default is None.
         acc_strategy (Literal['token', 'seq']): Strategy for accumulation. Default is 'token'.
diff --git a/swift/llm/model/register.py b/swift/llm/model/register.py
index e81716c56..ca5003ce3 100644
--- a/swift/llm/model/register.py
+++ b/swift/llm/model/register.py
@@ -176,8 +176,9 @@ def get_model_tokenizer_from_local(model_dir: str,
     if tokenizer is None:
         tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

-    if model_kwargs.get('num_labels') is not None:
-        model_config.num_labels = model_kwargs.pop('num_labels')
+    num_labels = model_kwargs.pop('num_labels', None)
+    if num_labels:
+        model_config.num_labels = num_labels

     model = None
     if load_model:
diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index f939ff916..16914a59b 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -697,6 +697,8 @@ def is_training(self):
         return self.mode not in {'vllm', 'lmdeploy', 'pt'}

     def set_mode(self, mode: Literal['vllm', 'lmdeploy', 'pt', 'seq_cls', 'train', 'rlhf', 'kto']) -> None:
+        if mode == 'causal_lm':
+            mode = 'train'
         self.mode = mode

     def register_post_encode_hook(self, models: List[nn.Module]) -> None:
diff --git a/swift/llm/train/sft.py b/swift/llm/train/sft.py
index a159288c3..979e2c072 100644
--- a/swift/llm/train/sft.py
+++ b/swift/llm/train/sft.py
@@ -80,7 +80,7 @@ def _prepare_template(self, use_chat_template: Optional[bool] = None) -> None:
         logger.info(f'default_system: {template.template_meta.default_system}')
         if template.use_model:
             template.model = self.model
-        template.set_mode('train' if args.num_labels is None else 'seq_cls')
+        template.set_mode(args.task_type)
         self.template = template

     def _get_dataset(self):
diff --git a/swift/llm/train/tuner.py b/swift/llm/train/tuner.py
index 20274161d..94a62b61a 100644
--- a/swift/llm/train/tuner.py
+++ b/swift/llm/train/tuner.py
@@ -153,7 +153,7 @@ def prepare_adapter(args: TrainArguments, model, *, template=None, train_dataset
         'lorap_lr_ratio': args.lorap_lr_ratio,
         'init_lora_weights': args.init_weights,
     }
-    task_type = 'CAUSAL_LM' if args.num_labels is None else 'SEQ_CLS'
+    task_type = args.task_type.upper()
     if args.train_type in ('lora', 'longlora'):
         if args.use_swift_lora:
             lora_config = LoRAConfig(lora_dtype=args.lora_dtype, **lora_kwargs)
diff --git a/swift/plugin/metric.py b/swift/plugin/metric.py
index d942713f6..b34f19d78 100644
--- a/swift/plugin/metric.py
+++ b/swift/plugin/metric.py
@@ -132,24 +132,20 @@ def compute_acc(preds,
     if isinstance(preds, torch.Tensor):
         preds = preds.cpu().numpy()
         labels = labels.cpu().numpy()
-
-    if is_encoder_decoder:
-        labels = labels[..., :]
-        preds = preds[..., :]
-    else:
+    if preds.ndim >= 2 and not is_encoder_decoder:
         labels = labels[..., 1:]
         preds = preds[..., :-1]
     if preds.shape != labels.shape:
         return {}

     masks = labels != -100
-    if acc_strategy == 'seq':
+    if acc_strategy == 'token' or preds.ndim == 1:
+        acc_list = (preds[masks] == labels[masks]).tolist()
+    else:
         acc_list = []
         for i, m in enumerate(masks):
             acc_list.append(np.all(preds[i, m] == labels[i, m]))
-    else:
-        acc_list = (preds[masks] == labels[masks]).tolist()
-    return {f'{acc_strategy}_acc': acc_list}
+    return {f'{acc_strategy}_acc' if preds.ndim >= 2 else 'acc': acc_list}


 def compute_acc_metrics(eval_prediction: EvalPrediction,
diff --git a/swift/trainers/mixin.py b/swift/trainers/mixin.py
index fae150efa..52c98093d 100644
--- a/swift/trainers/mixin.py
+++ b/swift/trainers/mixin.py
@@ -6,9 +6,11 @@
 import time
 from contextlib import contextmanager
 from copy import copy
+from swift.plugin import MeanMetric, compute_acc
 from functools import wraps
 from types import MethodType
 from typing import Callable, Dict, List, Optional, Tuple, Union
+from swift.utils.torchacc_utils import ta_trim_graph

 import safetensors
 import torch
@@ -31,7 +33,7 @@
 from swift.llm import Template
 from swift.plugin import extra_tuners
 from swift.tuners import SwiftModel
-from swift.utils import get_logger, is_mp_ddp
+from swift.utils import get_logger, is_mp_ddp, use_torchacc
 from .arguments import TrainingArguments
 from .optimizers.galore import create_optimizer_and_scheduler
 from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model
@@ -377,3 +379,19 @@ def get_train_dataloader(self):
         else:
             from swift.trainers.xtuner import get_xtuner_train_dataloader
             return get_xtuner_train_dataloader(self)
+
+    def _compute_acc(self, outputs, labels) -> None:
+        args = self.args
+        acc_steps = args.acc_steps
+        preds = outputs.logits.argmax(dim=-1)
+        if self.state.global_step % acc_steps == 0:
+            if use_torchacc():
+                ta_trim_graph()
+                preds = preds.to('cpu')
+                labels = labels.to('cpu')
+            metrics = compute_acc(
+                preds, labels, acc_strategy=args.acc_strategy, is_encoder_decoder=args.is_encoder_decoder)
+            for k, v in metrics.items():
+                if k not in self._custom_metrics:
+                    self._custom_metrics[k] = MeanMetric(nan_value=None)
+                self._custom_metrics[k].update(v)
diff --git a/swift/trainers/rlhf_trainer/rlhf_mixin.py b/swift/trainers/rlhf_trainer/rlhf_mixin.py
index 3d57b83f0..708e65b46 100644
--- a/swift/trainers/rlhf_trainer/rlhf_mixin.py
+++ b/swift/trainers/rlhf_trainer/rlhf_mixin.py
@@ -150,7 +150,7 @@ def get_batch_logps(self, logits: torch.FloatTensor, labels: torch.LongTensor, *
         labels = labels.clone()  # fix trl bug
         return super().get_batch_logps(logits, labels, *args, **kwargs)

-    def compute_loss(self, model, inputs, return_outputs=None, num_items_in_batch=None):
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         res = super().compute_loss(model, inputs, return_outputs=return_outputs)
         # compat transformers>=4.46.*
         if num_items_in_batch is not None:
diff --git a/swift/trainers/trainer_factory.py b/swift/trainers/trainer_factory.py
index f81b256ee..480ca8287 100644
--- a/swift/trainers/trainer_factory.py
+++ b/swift/trainers/trainer_factory.py
@@ -11,7 +11,7 @@ class TrainerFactory:
     TRAINER_MAPPING = {
-        'seq2seq': 'swift.trainers.Seq2SeqTrainer',
+        'causal_lm': 'swift.trainers.Seq2SeqTrainer',
         'seq_cls': 'swift.trainers.Trainer',
         'dpo': 'swift.trainers.DPOTrainer',
         'orpo': 'swift.trainers.ORPOTrainer',
@@ -22,7 +22,7 @@ class TrainerFactory:
     TRAINING_ARGS_MAPPING = {
-        'seq2seq': 'swift.trainers.Seq2SeqTrainingArguments',
+        'causal_lm': 'swift.trainers.Seq2SeqTrainingArguments',
         'seq_cls': 'swift.trainers.TrainingArguments',
         'dpo': 'swift.trainers.DPOConfig',
         'orpo': 'swift.trainers.ORPOConfig',
@@ -36,10 +36,8 @@ def get_cls(args, mapping: Dict[str, str]):
         if hasattr(args, 'rlhf_type'):
             train_method = args.rlhf_type
-        elif args.num_labels is None:
-            train_method = 'seq2seq'
         else:
-            train_method = 'seq_cls'
+            train_method = args.task_type
         module_path, class_name = mapping[train_method].rsplit('.', 1)
         module = importlib.import_module(module_path)
         return getattr(module, class_name)
diff --git a/swift/trainers/trainers.py b/swift/trainers/trainers.py
index e09090199..045119720 100644
--- a/swift/trainers/trainers.py
+++ b/swift/trainers/trainers.py
@@ -13,16 +13,22 @@
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 from transformers.utils import is_peft_available

-from swift.plugin import MeanMetric, compute_acc
-from swift.utils import JsonlWriter, Serializer, use_torchacc
-from swift.utils.torchacc_utils import ta_trim_graph
-from .arguments import Seq2SeqTrainingArguments
+from swift.utils import JsonlWriter, Serializer
+from .arguments import TrainingArguments, Seq2SeqTrainingArguments
 from .mixin import SwiftMixin
 from .torchacc_mixin import TorchAccMixin


 class Trainer(SwiftMixin, HfTrainer):
-    pass
+    args: TrainingArguments
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        kwargs = {}
+        if num_items_in_batch is not None:
+            kwargs['num_items_in_batch'] = num_items_in_batch
+        loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)
+        if inputs.get('labels') is not None:
+            self._compute_acc(outputs, inputs['labels'])
+        return (loss, outputs) if return_outputs else loss


 class Seq2SeqTrainer(TorchAccMixin, SwiftMixin, HfSeq2SeqTrainer):
@@ -95,7 +101,7 @@ def prediction_step(
             labels_list = pad_sequence(labels_list, batch_first=True, padding_value=0)
         return None, response_list, labels_list

-    def compute_loss(self, model, inputs, return_outputs=None, num_items_in_batch=None):
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
         loss_kwargs = {}
         labels = None
         if (self.label_smoother is not None or self.compute_loss_func is not None) and 'labels' in inputs:
@@ -146,23 +152,7 @@
         if getattr(self.args, 'average_tokens_across_devices', False):
             loss *= self.accelerator.num_processes

-        if outputs.logits is not None:
-            # In case of Liger
-            self._compute_token_acc(outputs, labels)
+        if outputs.logits is not None and labels is not None:
+            # Liger does not have logits
+            self._compute_acc(outputs, labels)
         return (loss, outputs) if return_outputs else loss
-
-    def _compute_token_acc(self, outputs, labels) -> None:
-
-        acc_steps = self.args.acc_steps
-        preds = outputs.logits.argmax(dim=2)
-        if self.state.global_step % acc_steps == 0:
-            if use_torchacc():
-                ta_trim_graph()
-                preds = preds.to('cpu')
-                labels = labels.to('cpu')
-            metrics = compute_acc(
-                preds, labels, acc_strategy=self.args.acc_strategy, is_encoder_decoder=self.args.is_encoder_decoder)
-            for k, v in metrics.items():
-                if k not in self._custom_metrics:
-                    self._custom_metrics[k] = MeanMetric(nan_value=None)
-                self._custom_metrics[k].update(v)
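
Usage note (editor addition, not part of the patch): after this change, passing --num_labels is enough to
switch the whole pipeline to sequence classification. BaseArguments._init_task_type() infers
task_type='seq_cls', which in turn selects AutoModelForSequenceClassification in get_model_processor,
swift.trainers.Trainer via TrainerFactory, and the peft task type 'SEQ_CLS' via task_type.upper().
The sketch below is a standalone Python mirror of that inference rule, written for illustration only;
infer_task_type is a hypothetical helper, not a function in the patch or in swift:

    from typing import Optional

    def infer_task_type(task_type: Optional[str], num_labels: Optional[int]) -> str:
        # An explicitly passed --task_type always wins (see _init_task_type above).
        if task_type is not None:
            return task_type
        # Otherwise, the presence of --num_labels selects classification.
        return 'causal_lm' if num_labels is None else 'seq_cls'

    assert infer_task_type(None, None) == 'causal_lm'    # default: LM training
    assert infer_task_type(None, 2) == 'seq_cls'         # --num_labels 2 => classification
    assert infer_task_type('seq_cls', None) == 'seq_cls' # explicit override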