Skip to content

Commit

Permalink
Support the `seq_cls` (sequence classification) task type via a new `task_type` argument
Browse files Browse the repository at this point in the history
  • Loading branch information
Jintao-Huang committed Dec 24, 2024
1 parent 97e39da commit 4aa239a
Show file tree
Hide file tree
Showing 11 changed files with 62 additions and 47 deletions.
13 changes: 12 additions & 1 deletion swift/llm/argument/base_args/base_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,12 @@ class BaseArguments(CompatArguments, GenerationArguments, QuantizeArguments, Dat
load_data_args (bool): Flag to determine if dataset configuration should be loaded. Default is False.
use_hf (bool): Flag to determine if Hugging Face should be used. Default is False.
hub_token (Optional[str]): SDK token for authentication. Default is None.
num_labels (Optional[int]): Number of labels for classification tasks. Default is None.
custom_register_path (List[str]): Path to custom .py file for dataset registration. Default is None.
ignore_args_error (bool): Flag to ignore argument errors for notebook compatibility. Default is False.
use_swift_lora (bool): Use swift lora, a compatible argument
"""
task_type: Literal['causal_lm', 'seq_cls'] = None
tuner_backend: Literal['peft', 'unsloth'] = 'peft'
train_type: str = field(default='lora', metadata={'help': f'train_type choices: {list(get_supported_tuners())}'})
adapters: List[str] = field(default_factory=list)
Expand Down Expand Up @@ -112,6 +114,14 @@ def _init_custom_register(self) -> None:
__import__(fname.rstrip('.py'))
logger.info(f'Successfully registered `{self.custom_register_path}`')

def _init_task_type(self):
if self.task_type is not None:
return
if self.num_labels is None:
self.task_type = 'causal_lm'
else:
self.task_type = 'seq_cls'

def _init_adapters(self):
if isinstance(self.adapters, str):
self.adapters = [self.adapters]
Expand All @@ -124,6 +134,7 @@ def __post_init__(self):
self.use_hf = True
os.environ['USE_HF'] = '1'
CompatArguments.__post_init__(self)
self._init_task_type()
self._init_adapters()
self._init_ckpt_dir()
self._init_custom_register()
Expand Down Expand Up @@ -252,7 +263,7 @@ def get_model_processor(self, *, model=None, model_type=None, model_revision=Non
kwargs['model_revision'] = model_revision or self.model_revision

model_kwargs = {}
if self.num_labels is not None:
if self.task_type == 'seq_cls':
from transformers import AutoModelForSequenceClassification
kwargs['automodel_class'] = AutoModelForSequenceClassification
model_kwargs = {'num_labels': self.num_labels}
Expand Down
1 change: 0 additions & 1 deletion swift/llm/argument/train_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ class TrainArguments(TorchAccArguments, TunerArguments, Seq2SeqTrainingOverrideA
resume_only_model (bool): Flag to resume training only the model. Default is False.
check_model (bool): Flag to check the model is latest. Default is True.
loss_type (Optional[str]): Type of loss function to use. Default is None.
num_labels (Optional[int]): Number of labels for classification tasks. Default is None.
packing (bool): Flag to enable packing of datasets. Default is False.
lazy_tokenize (Optional[bool]): Flag to enable lazy tokenization. Default is None.
acc_strategy (Literal['token', 'seq']): Strategy for accumulation. Default is 'token'.
Expand Down
5 changes: 3 additions & 2 deletions swift/llm/model/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,9 @@ def get_model_tokenizer_from_local(model_dir: str,
if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

if model_kwargs.get('num_labels') is not None:
model_config.num_labels = model_kwargs.pop('num_labels')
num_labels = model_kwargs.pop('num_labels', None)
if num_labels:
model_config.num_labels = num_labels

model = None
if load_model:
Expand Down
2 changes: 2 additions & 0 deletions swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,8 @@ def is_training(self):
return self.mode not in {'vllm', 'lmdeploy', 'pt'}

def set_mode(self, mode: Literal['causal_lm', 'vllm', 'lmdeploy', 'pt', 'seq_cls', 'train', 'rlhf', 'kto']) -> None:
    """Set the template's operating mode.

    ``'causal_lm'`` is accepted as an alias for ``'train'`` so callers can
    pass ``args.task_type`` directly (the body already normalizes it, but
    the annotation previously did not admit the value).
    """
    if mode == 'causal_lm':
        mode = 'train'
    self.mode = mode

def register_post_encode_hook(self, models: List[nn.Module]) -> None:
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/train/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def _prepare_template(self, use_chat_template: Optional[bool] = None) -> None:
logger.info(f'default_system: {template.template_meta.default_system}')
if template.use_model:
template.model = self.model
template.set_mode('train' if args.num_labels is None else 'seq_cls')
template.set_mode(args.task_type)
self.template = template

def _get_dataset(self):
Expand Down
2 changes: 1 addition & 1 deletion swift/llm/train/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def prepare_adapter(args: TrainArguments, model, *, template=None, train_dataset
'lorap_lr_ratio': args.lorap_lr_ratio,
'init_lora_weights': args.init_weights,
}
task_type = 'CAUSAL_LM' if args.num_labels is None else 'SEQ_CLS'
task_type = args.task_type.upper()
if args.train_type in ('lora', 'longlora'):
if args.use_swift_lora:
lora_config = LoRAConfig(lora_dtype=args.lora_dtype, **lora_kwargs)
Expand Down
14 changes: 5 additions & 9 deletions swift/plugin/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,24 +132,20 @@ def compute_acc(preds,
if isinstance(preds, torch.Tensor):
preds = preds.cpu().numpy()
labels = labels.cpu().numpy()

if is_encoder_decoder:
labels = labels[..., :]
preds = preds[..., :]
else:
if preds.ndim >= 2 and not is_encoder_decoder:
labels = labels[..., 1:]
preds = preds[..., :-1]
if preds.shape != labels.shape:
return {}

masks = labels != -100
if acc_strategy == 'seq':
if acc_strategy == 'token' or preds.ndim == 1:
acc_list = (preds[masks] == labels[masks]).tolist()
else:
acc_list = []
for i, m in enumerate(masks):
acc_list.append(np.all(preds[i, m] == labels[i, m]))
else:
acc_list = (preds[masks] == labels[masks]).tolist()
return {f'{acc_strategy}_acc': acc_list}
return {f'{acc_strategy}_acc' if preds.ndim >= 2 else 'acc': acc_list}


def compute_acc_metrics(eval_prediction: EvalPrediction,
Expand Down
20 changes: 19 additions & 1 deletion swift/trainers/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
import time
from contextlib import contextmanager
from copy import copy
from swift.plugin import MeanMetric, compute_acc
from functools import wraps
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union
from swift.utils.torchacc_utils import ta_trim_graph

import safetensors
import torch
Expand All @@ -31,7 +33,7 @@
from swift.llm import Template
from swift.plugin import extra_tuners
from swift.tuners import SwiftModel
from swift.utils import get_logger, is_mp_ddp
from swift.utils import get_logger, is_mp_ddp, use_torchacc
from .arguments import TrainingArguments
from .optimizers.galore import create_optimizer_and_scheduler
from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model
Expand Down Expand Up @@ -377,3 +379,19 @@ def get_train_dataloader(self):
else:
from swift.trainers.xtuner import get_xtuner_train_dataloader
return get_xtuner_train_dataloader(self)

def _compute_acc(self, outputs, labels) -> None:
    """Accumulate batch accuracy into ``self._custom_metrics``.

    Only runs every ``args.acc_steps`` optimizer steps, keeping the
    argmax/compare overhead off the common path.
    """
    args = self.args
    acc_steps = args.acc_steps
    # Greedy prediction from the model head's logits.
    preds = outputs.logits.argmax(dim=-1)
    if self.state.global_step % acc_steps == 0:
        if use_torchacc():
            # NOTE(review): indentation was lost in this view — assuming the
            # two .to('cpu') transfers below belong to the torchacc branch
            # (mirrors the pre-commit _compute_token_acc); confirm upstream.
            ta_trim_graph()
            preds = preds.to('cpu')
            labels = labels.to('cpu')
        metrics = compute_acc(
            preds, labels, acc_strategy=args.acc_strategy, is_encoder_decoder=args.is_encoder_decoder)
        for k, v in metrics.items():
            # Lazily create one running-mean metric per reported key.
            if k not in self._custom_metrics:
                self._custom_metrics[k] = MeanMetric(nan_value=None)
            self._custom_metrics[k].update(v)
2 changes: 1 addition & 1 deletion swift/trainers/rlhf_trainer/rlhf_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def get_batch_logps(self, logits: torch.FloatTensor, labels: torch.LongTensor, *
labels = labels.clone() # fix trl bug
return super().get_batch_logps(logits, labels, *args, **kwargs)

def compute_loss(self, model, inputs, return_outputs=None, num_items_in_batch=None):
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
res = super().compute_loss(model, inputs, return_outputs=return_outputs)
# compat transformers>=4.46.*
if num_items_in_batch is not None:
Expand Down
8 changes: 3 additions & 5 deletions swift/trainers/trainer_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class TrainerFactory:
TRAINER_MAPPING = {
'seq2seq': 'swift.trainers.Seq2SeqTrainer',
'causal_lm': 'swift.trainers.Seq2SeqTrainer',
'seq_cls': 'swift.trainers.Trainer',
'dpo': 'swift.trainers.DPOTrainer',
'orpo': 'swift.trainers.ORPOTrainer',
Expand All @@ -22,7 +22,7 @@ class TrainerFactory:
}

TRAINING_ARGS_MAPPING = {
'seq2seq': 'swift.trainers.Seq2SeqTrainingArguments',
'causal_lm': 'swift.trainers.Seq2SeqTrainingArguments',
'seq_cls': 'swift.trainers.TrainingArguments',
'dpo': 'swift.trainers.DPOConfig',
'orpo': 'swift.trainers.ORPOConfig',
Expand All @@ -36,10 +36,8 @@ class TrainerFactory:
def get_cls(args, mapping: Dict[str, str]):
    """Resolve the trainer/training-args class for ``args`` from ``mapping``.

    RLHF runs dispatch on ``args.rlhf_type``; otherwise ``args.task_type``
    ('causal_lm' or 'seq_cls') selects the entry. The span here mixed the
    deleted ``num_labels``-based branches with their replacement; this is
    the coherent post-commit version. The class is imported lazily so heavy
    trainer modules load only when actually used.
    """
    if hasattr(args, 'rlhf_type'):
        train_method = args.rlhf_type
    else:
        train_method = args.task_type
    # mapping values are dotted paths: '<module path>.<ClassName>'.
    module_path, class_name = mapping[train_method].rsplit('.', 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
Expand Down
40 changes: 15 additions & 25 deletions swift/trainers/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,22 @@
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.utils import is_peft_available

from swift.plugin import MeanMetric, compute_acc
from swift.utils import JsonlWriter, Serializer, use_torchacc
from swift.utils.torchacc_utils import ta_trim_graph
from .arguments import Seq2SeqTrainingArguments
from swift.utils import JsonlWriter, Serializer
from .arguments import TrainingArguments, Seq2SeqTrainingArguments
from .mixin import SwiftMixin
from .torchacc_mixin import TorchAccMixin


class Trainer(SwiftMixin, HfTrainer):
    """HF Trainer specialization for sequence-classification runs.

    The span mixed the deleted placeholder body (``pass``) with the new
    class body from the diff; this is the coherent post-commit version.
    """

    # Narrowed type of the inherited ``args`` attribute.
    args: TrainingArguments

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Delegate loss computation to HF, then record accuracy.

        ``num_items_in_batch`` is forwarded only when provided, staying
        compatible with transformers versions that lack the parameter.
        """
        kwargs = {}
        if num_items_in_batch is not None:
            kwargs['num_items_in_batch'] = num_items_in_batch
        # Always request outputs so logits are available for accuracy.
        loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)
        if inputs.get('labels') is not None:
            self._compute_acc(outputs, inputs['labels'])
        return (loss, outputs) if return_outputs else loss


class Seq2SeqTrainer(TorchAccMixin, SwiftMixin, HfSeq2SeqTrainer):
Expand Down Expand Up @@ -95,7 +101,7 @@ def prediction_step(
labels_list = pad_sequence(labels_list, batch_first=True, padding_value=0)
return None, response_list, labels_list

def compute_loss(self, model, inputs, return_outputs=None, num_items_in_batch=None):
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
loss_kwargs = {}
labels = None
if (self.label_smoother is not None or self.compute_loss_func is not None) and 'labels' in inputs:
Expand Down Expand Up @@ -146,23 +152,7 @@ def compute_loss(self, model, inputs, return_outputs=None, num_items_in_batch=No
if getattr(self.args, 'average_tokens_across_devices', False):
loss *= self.accelerator.num_processes

if outputs.logits is not None:
# In case of Liger
self._compute_token_acc(outputs, labels)
if outputs.logits is not None and labels is not None:
# Liger does not have logits
self._compute_acc(outputs, labels)
return (loss, outputs) if return_outputs else loss

def _compute_token_acc(self, outputs, labels) -> None:

acc_steps = self.args.acc_steps
preds = outputs.logits.argmax(dim=2)
if self.state.global_step % acc_steps == 0:
if use_torchacc():
ta_trim_graph()
preds = preds.to('cpu')
labels = labels.to('cpu')
metrics = compute_acc(
preds, labels, acc_strategy=self.args.acc_strategy, is_encoder_decoder=self.args.is_encoder_decoder)
for k, v in metrics.items():
if k not in self._custom_metrics:
self._custom_metrics[k] = MeanMetric(nan_value=None)
self._custom_metrics[k].update(v)

0 comments on commit 4aa239a

Please sign in to comment.